<h1 id="addendum-quantifying-our-flawed-intuitions">Addendum: quantifying our flawed intuitions</h1>
<p><em>Sander Dieleman, 2020-09-01, <a href="https://benanne.github.io/2020/09/01/typicality-addendum">benanne.github.io/2020/09/01/typicality-addendum</a></em></p>
<p>This post is an addendum to <a href="/2020/09/01/typicality.html">my blog post about typicality</a>. Please consider reading that first, if you haven’t already. Here, I will try to quantify what happens when our intuitions fail us in high-dimensional spaces.</p>
<p><em>Note that the practical relevance of this is limited, so consider this a piece of optional extra content!</em></p>
<p>In the ‘unfair coin flips’ example from the main blog post, it’s actually pretty clear what happens when our intuitions fail us: we think of the binomial distribution, <strong>ignoring the order of the sequences as a factor, when we should actually be taking it into account</strong>. Referring back to the table from section 2.1, we use the probabilities in the rightmost column, when we should be using those in the third column. But when we think of a high-dimensional Gaussian distribution and come to the wrong conclusion, what distribution are we <em>actually</em> thinking of?</p>
<h2 id="the-gaussian-distribution-mathcaln_k">The Gaussian distribution \(\mathcal{N}_K\)</h2>
<figure>
<img src="/images/bubbles.jpg" />
</figure>
<p>Let’s start by quantifying what a multivariate Gaussian distribution actually looks like: let \(\mathbf{x} \sim \mathcal{N}(\mathbf{0}, I_K)\), a standard Gaussian distribution in \(K\) dimensions, henceforth referred to as \(\mathcal{N}_K\). We can sample from it by drawing \(K\) independent one-dimensional samples \(x_i \sim \mathcal{N}(0, 1)\), and joining them into a vector \(\mathbf{x}\). This distribution is <strong>spherically symmetric</strong>, which makes it very natural to think about samples in terms of their <strong>distance to the mode</strong> (in this case, the origin, corresponding to the zero-vector \(\mathbf{0}\)), because all samples at a given distance \(r\) have the same density.</p>
<p>Now, let’s look at the distribution of \(r\): it seems as if the multivariate Gaussian distribution \(\mathcal{N}_K\) naturally arises by taking a univariate version of it, and rotating it around the mode in every possible direction in \(K\)-dimensional space. Because each of these individual rotated copies is Gaussian, this in turn might seem to imply that the distance from the mode \(r\) is itself Gaussian (or rather half-Gaussian, since it is a nonnegative quantity). But this is incorrect! \(r\) actually follows a <a href="https://en.wikipedia.org/wiki/Chi_distribution"><strong>chi distribution</strong></a> with \(K\) degrees of freedom: \(r \sim \chi_K\).</p>
<p>Note that for \(K = 1\), this does indeed correspond to a half-Gaussian distribution. But as \(K\) increases, the mode of the chi distribution rapidly shifts away from 0: it actually sits at \(\sqrt{K - 1}\). This leaves considerably less probability mass near 0, where the mode of our original multivariate Gaussian \(\mathcal{N}_K\) is located.</p>
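<p>As a quick sanity check (a sketch of mine, not part of the original derivation), we can draw samples from \(\mathcal{N}_K\) and compare the empirical distribution of \(r\) against <code>scipy.stats.chi</code>:</p>

```python
import numpy as np
import scipy.stats

rng = np.random.default_rng(0)
K = 100

# Distances from the mode for 100,000 samples of a standard Gaussian in K dimensions.
r = np.linalg.norm(rng.standard_normal((100000, K)), axis=1)

# The empirical mean should match the chi distribution with K degrees of freedom,
# and no sample ends up anywhere near the mode at the origin.
print(r.mean(), scipy.stats.chi(K).mean(), np.sqrt(K - 1))
```

<p>For \(K = 100\), all distances cluster tightly around \(\sqrt{K - 1} \approx 10\), far away from the mode.</p>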
<p>This exercise yields an alternative sampling strategy for multivariate Gaussians: first, sample a distance from the mode \(r \sim \chi_K\). Then, sample a direction, i.e. a vector on the \(K\)-dimensional unit sphere \(S^K\), uniformly at random: \(\mathbf{\theta} \sim U[S^K]\). Multiply them together to obtain a Gaussian sample: \(\mathbf{x} = r \cdot \mathbf{\theta} \sim \mathcal{N}_K\).</p>
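<p>Here is a minimal sketch of this two-step sampling strategy (using <code>scipy.stats.chi</code> for the radial part), verifying that the resulting coordinates are indeed standard Gaussian:</p>

```python
import numpy as np
import scipy.stats

rng = np.random.default_rng(1)
n, K = 200000, 10

# Step 1: sample distances from the mode, r ~ chi_K.
r = scipy.stats.chi(K).rvs(size=(n, 1), random_state=rng)

# Step 2: sample directions uniformly on the unit sphere
# (a normalised Gaussian vector is uniform on the sphere).
theta = rng.standard_normal((n, K))
theta /= np.linalg.norm(theta, axis=1, keepdims=True)

x = r * theta

# Each coordinate should have mean ~0 and variance ~1.
print(x.mean(), x.var())
```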
<h2 id="the-gaussian-mirage-distribution-mathcalm_k">The Gaussian mirage distribution \(\mathcal{M}_K\)</h2>
<figure>
<img src="/images/mirage.jpg" />
</figure>
<p>What if, instead of sampling \(r \sim \chi_K\), we sampled \(r \sim \mathcal{N}(0, K)\)? Note that \(\mathbb{E}[r^2] = K\) in both cases, so this change preserves the scale of the resulting vectors. For \(K = 1\), we get the same distribution for \(\mathbf{x}\), but for \(K > 1\), we get something very different. The resulting distribution represents what we might think the multivariate Gaussian distribution looks like, if we rely on a mistaken intuition and squint a bit. Let’s call this the <strong>Gaussian mirage</strong> distribution, denoted by \(\mathcal{M}\): \(\mathbf{x} = r \cdot \mathbf{\theta} \sim \mathcal{M}_K\). (If this thing already has a name, I’m not aware of it, so please let me know!)</p>
<p>We’ve already established that \(\mathcal{M}_1 \equiv \mathcal{N}_1\). But in higher dimensions, these distributions behave very differently. One way to comprehend this is to look at a flattened histogram of samples across all coordinates:</p>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code>import matplotlib.pyplot as plt
import numpy as np

def gaussian(n, k):
    return np.random.normal(0, 1, (n, k))

def mirage(n, k):
    direction = np.random.normal(0, 1, (n, k))
    direction /= np.sqrt(np.sum(direction**2, axis=-1, keepdims=True))
    distance = np.random.normal(0, np.sqrt(k), (n, 1))
    return distance * direction

def plot_histogram(x):
    plt.hist(x.ravel(), bins=100)
    plt.ylim(0, 80000)
    plt.xlim(-4, 4)
    plt.tick_params(labelleft=False, left=False, labelbottom=False, bottom=False)

plt.figure(figsize=(9, 3))
ks = [1, 3, 10, 100]
for i, k in enumerate(ks):
    plt.subplot(2, len(ks), i + 1)
    plt.title(f'K = {k}')
    plot_histogram(gaussian(10**6 // k, k))
    plt.subplot(2, len(ks), i + 1 + len(ks))
    plot_histogram(mirage(10**6 // k, k))
</code></pre></div></div>
<figure>
<a href="/images/gaussian_histograms.png"><img src="/images/gaussian_histograms.png" alt="Histograms of the flattened coordinates of the multivariate Gaussian distribution (top) and the Gaussian mirage (bottom)." /></a>
<figcaption>Histograms of the flattened coordinates of the multivariate Gaussian distribution (top) and the Gaussian mirage (bottom), for different dimensionalities (K). For the mirage, the histograms become increasingly peaked around 0 as the dimensionality increases.</figcaption>
</figure>
<p>For \(\mathcal{N}_K\), this predictably looks like a univariate Gaussian for all \(K\). For \(\mathcal{M}_K\), it becomes highly <a href="https://en.wikipedia.org/wiki/Kurtosis">leptokurtic</a> as \(K\) increases, indicating that <strong>dramatically more probability mass is located close to the mode</strong>.</p>
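<p>We can quantify this with the excess kurtosis of the flattened coordinates (a quick check of mine, reusing the <code>mirage</code> sampler from above with a seeded generator):</p>

```python
import numpy as np
import scipy.stats

rng = np.random.default_rng(2)

def mirage(n, k):
    # Uniform direction times a Gaussian-distributed distance from the mode.
    direction = rng.standard_normal((n, k))
    direction /= np.linalg.norm(direction, axis=1, keepdims=True)
    distance = rng.normal(0, np.sqrt(k), (n, 1))
    return distance * direction

ks = [1, 3, 10, 100]
# Excess kurtosis is ~0 for a Gaussian, and grows with K for the mirage.
kurt = [scipy.stats.kurtosis(mirage(10**6 // k, k).ravel()) for k in ks]
print(dict(zip(ks, kurt)))
```

<p>For the mirage, the excess kurtosis works out to \(\frac{9K}{K+2} - 3\), approaching 6 as \(K \to \infty\) (my own calculation, so caveat emptor); for \(K = 1\) it is 0, matching the Gaussian, as expected.</p>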
<h2 id="typical-sets-of-mathcaln_k-and-mathcalm_k">Typical sets of \(\mathcal{N}_K\) and \(\mathcal{M}_K\)</h2>
<p>Let’s also look at the typical sets for both of these distributions. For \(\mathcal{N}_K\), the probability density function (pdf) has the form:</p>
\[f_{\mathcal{N}_K}(\mathbf{x}) = (2 \pi)^{-\frac{K}{2}} \exp \left( -\frac{\mathbf{x}^T \mathbf{x}}{2} \right),\]
<p>and the differential entropy is given by:</p>
\[H_{\mathcal{N}_K} = \frac{K}{2} \log \left(2 \pi e \right) .\]
<p>To find the typical set, we just need to look for the \(\mathbf{x}\) where \(f_{\mathcal{N}_K}(\mathbf{x}) \approx 2^{-H_{\mathcal{N}_K}} = (2 \pi e)^{-\frac{K}{2}}\) (assuming the entropy is measured in bits). This is clearly the case when \(\mathbf{x}^T\mathbf{x} \approx K\), or in other words, for <strong>any \(\mathbf{x}\) whose distance from the mode is close to \(\sqrt{K}\)</strong>. This is the <em>Gaussian annulus</em> from before.</p>
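<p>The concentration of the squared norm around \(K\) is easy to verify empirically (a small check, not part of the original post):</p>

```python
import numpy as np

rng = np.random.default_rng(3)
K = 1000
x = rng.standard_normal((100000, K))

# ||x||^2 has mean K and standard deviation sqrt(2K),
# so its relative spread shrinks as K grows.
sq_norm = np.sum(x**2, axis=1)
print(sq_norm.mean() / K, sq_norm.std() / K)
```

<p>For \(K = 1000\), virtually every sample ends up with \(\Vert\mathbf{x}\Vert \approx \sqrt{K}\), even though the density is maximal at the origin.</p>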
<p>Let’s subject the Gaussian mirage \(\mathcal{M}_K\) to the same treatment. It’s not obvious how to express the pdf in terms of \(\mathbf{x}\), but it’s easier if we rewrite \(\mathbf{x}\) as \(r \cdot \mathbf{\theta}\), as before, and imagine the sampling procedure: first, pick a radius \(r \sim \mathcal{HN}(0, K)\) (the half-Gaussian distribution — using the Gaussian distribution complicates the math a bit, because the radius should be nonnegative), and then pick a position on the \(K\)-sphere with radius \(r\), uniformly at random:</p>
\[f_{\mathcal{M}_K}(\mathbf{x}) = f_{\mathcal{HN}(0, K)}(r) \cdot f_{U[S^K(r)]}(\theta) = \frac{2}{\sqrt{2 \pi K}} \exp \left( -\frac{r^2}{2 K} \right) \cdot \frac{1}{r^{K-1}} \frac{\Gamma\left( \frac{K}{2} \right)}{2 \pi ^ \frac{K}{2}} .\]
<p>The former factor is the density of the half-Gaussian distribution: note the additional factor 2 compared to the standard Gaussian density, because we only consider nonnegative values of \(r\). The latter is the density of a uniform distribution on the \(K\)-sphere with radius \(r\) (which is the inverse of its surface area). As an aside, this factor is worth taking a closer look at, because it behaves in a rather peculiar way. Here’s the surface area of a unit \(K\)-sphere for increasing \(K\):</p>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code>import matplotlib.pyplot as plt
import numpy as np
import scipy.special

K = np.arange(0, 30 + 1)
A = (2 * np.pi**(K / 2.0)) / scipy.special.gamma(K / 2.0)

plt.figure(figsize=(9, 3))
plt.stem(K, A, basefmt=' ')
plt.ylim(0, 35)
</code></pre></div></div>
<figure>
<a href="/images/sphere_area.png"><img src="/images/sphere_area.png" alt="Surface area of a K-dimensional unit sphere, for K ranging from 0 to 30." /></a>
<figcaption>Surface area of a K-dimensional unit sphere, for K ranging from 0 to 30.</figcaption>
</figure>
<p>Confused? You and me both! Believe it or not, <strong>the surface area of a \(K\)-sphere tends to zero with increasing \(K\)</strong> — but only after growing to a maximum at \(K = 7\) first. <a href="https://math.stackexchange.com/questions/67039/why-does-the-volume-of-the-unit-sphere-go-to-zero">High-dimensional spaces are <em>weird</em></a>.</p>
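<p>We can confirm the location of the maximum directly from the surface area formula \(\frac{2 \pi^{K/2}}{\Gamma(K/2)}\):</p>

```python
import numpy as np
import scipy.special

# Surface area of the unit K-sphere for K = 1 .. 30.
K = np.arange(1, 31)
A = 2 * np.pi**(K / 2.0) / scipy.special.gamma(K / 2.0)

# The maximum is attained at K = 7, after which the area decays towards zero.
print(K[np.argmax(A)], A.max(), A[-1])
```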
<p>Another thing worth noting is that the density at the mode \(f_{\mathcal{M}_K}(\mathbf{0}) = +\infty\) for \(K > 1\), which already suggests that this distribution has a lot of its mass concentrated near the mode.</p>
<p>Computing the entropy of this distribution takes a bit of work. The differential entropy is:</p>
\[H_{\mathcal{M}_K} = - \int_{\mathbb{R}^K} f_{\mathcal{M}_K}(\mathbf{x}) \log f_{\mathcal{M}_K}(\mathbf{x}) \mathrm{d}\mathbf{x} .\]
<p>We can use the radial symmetry of this density to reformulate this as an integral of a scalar function:</p>
\[H_{\mathcal{M}_K} = - \int_0^{+\infty} f_{\mathcal{M}_K}(r) \log f_{\mathcal{M}_K}(r) S^K(r) \mathrm{d} r,\]
<p>where \(S^K(r)\) is the surface area of a \(K\)-sphere with radius \(r\). Filling in the density function, we get:</p>
\[H_{\mathcal{M}_K} = - \int_0^{+\infty} \frac{2}{\sqrt{2 \pi K}} \exp \left( -\frac{r^2}{2 K} \right) \cdot \log \left( \frac{2}{\sqrt{2 \pi K}} \exp \left( -\frac{r^2}{2 K} \right) \cdot \frac{1}{r^{K-1}} \frac{\Gamma\left( \frac{K}{2} \right)}{2 \pi ^ \frac{K}{2}} \right) \mathrm{d} r,\]
<p>where we have made use of the fact that \(S^K(r)\) cancels out with the second factor of \(f_{\mathcal{M}_K}(r)\). We can split up the \(\log\) into three different terms, \(H_{\mathcal{M}_K} = H_1 + H_2 + H_3\):</p>
\[H_1 = - \int_0^{+\infty} \frac{2}{\sqrt{2 \pi K}} \exp \left( -\frac{r^2}{2 K} \right) \left(-\frac{r^2}{2 K} \right) \mathrm{d} r = \int_0^{+\infty} \frac{r^2}{\sqrt{2 \pi}} \exp \left( -\frac{r^2}{2} \right) \mathrm{d} r = \frac{1}{2},\]
\[H_2 = - \int_0^{+\infty} \frac{2}{\sqrt{2 \pi K}} \exp \left( -\frac{r^2}{2 K} \right) \log \left( \frac{1}{r^{K-1}} \right) \mathrm{d} r = \frac{K - 1}{2} \left( \log \frac{K}{2} - \gamma \right),\]
\[H_3 = - \int_0^{+\infty} \frac{2}{\sqrt{2 \pi K}} \exp \left( -\frac{r^2}{2 K} \right) \log \left( \frac{2}{\sqrt{2 \pi K}} \frac{\Gamma\left( \frac{K}{2} \right)}{2 \pi ^ \frac{K}{2}} \right) \mathrm{d} r = - \log \left( \frac{1}{\sqrt{2 \pi K}} \frac{\Gamma\left( \frac{K}{2} \right)}{\pi ^ \frac{K}{2}} \right),\]
<p>where we have taken \(\log\) to be the natural logarithm for convenience, and \(\gamma\) is the <a href="https://en.wikipedia.org/wiki/Euler%E2%80%93Mascheroni_constant">Euler-Mascheroni constant</a>. In summary:</p>
\[H_{\mathcal{M}_K} = \frac{1}{2} + \frac{K - 1}{2} \left( \log \frac{K}{2} - \gamma \right) - \log \left( \frac{1}{\sqrt{2 \pi K}} \frac{\Gamma\left( \frac{K}{2} \right)}{\pi ^ \frac{K}{2}} \right) .\]
<p>Note that \(H_{\mathcal{M}_1} = \frac{1}{2} \log (2 \pi e)\), matching the standard Gaussian distribution as expected.</p>
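<p>Since closed-form entropy derivations are easy to get wrong, here is a Monte Carlo cross-check (my own sketch): estimate \(H_{\mathcal{M}_K} = -\mathbb{E}[\log f_{\mathcal{M}_K}(\mathbf{x})]\) by sampling radii from the half-Gaussian and evaluating the log-density, which depends only on \(r\):</p>

```python
import numpy as np
import scipy.special

rng = np.random.default_rng(4)
K = 5
n = 10**6

# Sample radii from the half-Gaussian with scale sqrt(K).
r = np.abs(rng.normal(0, np.sqrt(K), n))

# log f_M(x) at radius r, in nats (see the density formula above).
log_f = (np.log(2) - 0.5 * np.log(2 * np.pi * K) - r**2 / (2 * K)
         - (K - 1) * np.log(r)
         + scipy.special.gammaln(K / 2) - np.log(2) - (K / 2) * np.log(np.pi))

H_mc = -log_f.mean()

# The closed-form expression derived above.
H = (0.5 + (K - 1) / 2 * (np.log(K / 2) - np.euler_gamma)
     - (-0.5 * np.log(2 * np.pi * K) + scipy.special.gammaln(K / 2)
        - (K / 2) * np.log(np.pi)))

print(H_mc, H)
```

<p>Both values agree to within Monte Carlo error.</p>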
<p>Because this is measured in nats, not in bits, we find the typical set where \(f_{\mathcal{M}_K}(\mathbf{x}) \approx \exp(-H_{\mathcal{M}_K})\). We must find \(r \geq 0\) so that</p>
\[\frac{r^2}{2 K} + (K - 1) \log r = \frac{1}{2} + \frac{K - 1}{2} \left( \log \frac{K}{2} - \gamma \right) .\]
<p>We can express the solution of this equation in terms of the Lambert \(W\) function:</p>
\[r = \sqrt{K (K - 1) W\left(\frac{1}{K (K - 1)} \exp \left( \frac{1}{K - 1} + \log \frac{K}{2} - \gamma \right) \right)} .\]
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code>import matplotlib.pyplot as plt
import numpy as np
import scipy.special

K = np.unique(np.round(np.logspace(0, 6, 100)))
w_arg = np.exp(1 / (K - 1) + np.log(K / 2) - np.euler_gamma) / (K * (K - 1))
r = np.sqrt(K * (K - 1) * scipy.special.lambertw(w_arg).real)  # lambertw returns a complex array
r[0] = 1  # Special case for K = 1.

plt.figure(figsize=(9, 3))
plt.plot(K, r / np.sqrt(K))
plt.xscale('log')
plt.ylim(0, 1.2)
plt.xlabel('$K$')
plt.ylabel('$\\frac{r}{\\sqrt{K}}$')
</code></pre></div></div>

<figure>
<a href="/images/mirage_radius.png"><img src="/images/mirage_radius.png" alt="The distance from the mode at which the typical set of the Gaussian mirage is found, as a function of K." /></a>
<figcaption>The distance from the mode at which the typical set of the Gaussian mirage is found, normalised by the standard deviation, as a function of K.</figcaption>
</figure>
<p>As \(K \to +\infty\), this seems to converge to the value \(0.52984 \sqrt{K}\), which is somewhere in between the mode (\(0\)) and the mean (\(\sqrt{\frac{2K}{\pi}} \approx 0.79788 \sqrt{K}\)) of the half-Gaussian distribution (which \(r\) follows by construction). This is not just an interesting curiosity: although it is clear that the typical set of \(\mathcal{M}_K\) is much closer to the mode than for \(\mathcal{N}_K\) (because \(r < \sqrt{K}\)), the mode is not unequivocally a member of the typical set. In fact, the definition of typical sets sort of breaks down for this distribution, because we need to allow for a very large range of probability densities to capture the bulk of its mass. In this sense, it behaves a lot more like the one-dimensional Gaussian. Nevertheless, even this strange concoction of a distribution exhibits unintuitive behaviour in high-dimensional space!</p>
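<p>As an aside: using the small-argument approximation \(W(z) \approx z\), the limit appears to work out to exactly \(\sqrt{e^{-\gamma}/2} \approx 0.52984\) (my own back-of-the-envelope derivation, so take it with a grain of salt). A quick numerical check:</p>

```python
import numpy as np
import scipy.special

# Conjectured limit of r / sqrt(K) as K goes to infinity.
limit = np.sqrt(np.exp(-np.euler_gamma) / 2)

# Evaluate the exact Lambert W expression at a very large K.
K = 1e12
w_arg = np.exp(1 / (K - 1) + np.log(K / 2) - np.euler_gamma) / (K * (K - 1))
r = np.sqrt(K * (K - 1) * scipy.special.lambertw(w_arg).real)

print(limit, r / np.sqrt(K))
```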
<p><em>If you would like to cite this post in an academic context, you can use this BibTeX snippet:</em></p>
<div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code>@misc{dieleman2020typicality,
author = {Dieleman, Sander},
title = {Musings on typicality},
url = {https://benanne.github.io/2020/09/01/typicality.html},
year = {2020}
}
</code></pre></div></div>

<h1 id="musings-on-typicality">Musings on typicality</h1>
<p><em>Sander Dieleman, 2020-09-01, <a href="https://benanne.github.io/2020/09/01/typicality">benanne.github.io/2020/09/01/typicality</a></em></p>
<p>If you’re training or sampling from generative models, <strong>typicality</strong> is a concept worth understanding. It sheds light on why beam search doesn’t work for autoregressive models of images, audio and video; why you can’t just threshold the likelihood to perform anomaly detection with generative models; and why high-dimensional Gaussians are “soap bubbles”. This post is a summary of my current thoughts on the topic.</p>
<p>First, some context: one of the reasons I’m writing this, is to structure my own thoughts about typicality and the unintuitive behaviour of high-dimensional probability distributions. Most of these thoughts have not been empirically validated, and several are <strong>highly speculative</strong> and could be wrong. Please bear this in mind when reading, and don’t hesitate to use the comments section to correct me. Another reason is to draw more attention to the concept, as I’ve personally found it extremely useful to gain insight into the behaviour of generative models, and to correct some of my flawed intuitions. I <a href="https://twitter.com/sedielem/status/1264587646321516544">tweeted</a> about typicality a few months ago, but as it turns out, I have a lot more to say on the topic!</p>
<p>As with most of my blog posts, I will assume a degree of familiarity with machine learning. For certain parts, some knowledge of generative modelling is probably useful as well. <a href="https://benanne.github.io/2020/03/24/audio-generation.html#generative-models">Section 3 of my previous blog post</a> provides an overview of generative models.</p>
<p><strong>Overview</strong> (click to scroll to each section):</p>
<ol>
<li><em><a href="#likelihood">The joys of likelihood</a></em></li>
<li><em><a href="#examples">Motivating examples</a></em></li>
<li><em><a href="#abstraction">Abstraction and the curse of dimensionality</a></em></li>
<li><em><a href="#typicality">Typicality</a></em></li>
<li><em><a href="#in-the-wild">Typicality in the wild</a></em></li>
<li><em><a href="#right-level">The right level of abstraction</a></em></li>
<li><em><a href="#closing-thoughts">Closing thoughts</a></em></li>
<li><em><a href="#acknowledgements">Acknowledgements</a></em></li>
<li><em><a href="#references">References</a></em></li>
</ol>
<h2 id="-the-joys-of-likelihood"><a name="likelihood"></a> The joys of likelihood</h2>
<p>When it comes to generative modelling, my personal preference for the <strong>likelihood-based paradigm</strong> is no secret (my recent foray into <a href="https://deepmind.com/research/publications/End-to-End-Adversarial-Text-to-Speech">adversarial methods for text-to-speech</a> notwithstanding). While there are many other ways to build and train models (e.g. using adversarial networks, score matching, optimal transport, quantile regression, … see <a href="https://benanne.github.io/2020/03/24/audio-generation.html#generative-models">my previous blog post</a> for an overview), there is something intellectually pleasing about the simplicity of maximum likelihood training: the model explicitly parameterises a probability distribution, and we fit the parameters of that distribution so it is able to explain the observed data as well as possible (i.e., assigns to it the highest possible likelihood).</p>
<p>It turns out that this is far from the whole story, and <strong>‘<em>higher likelihood</em>’ doesn’t always mean <em>better</em> in a way that we actually care about</strong>. In fact, the way likelihood behaves in relation to the quality of a model as measured by humans (e.g. by inspecting samples) can be deeply unintuitive. This has been well-known in the machine learning community for some time, and Theis et al.’s <a href="https://arxiv.org/abs/1511.01844"><em>A note on the evaluation of generative models</em></a><sup id="fnref:anote" role="doc-noteref"><a href="#fn:anote" class="footnote">1</a></sup> does an excellent job of demonstrating this with clever thought experiments and concrete examples. In what follows, I will expound on what I think is going on when likelihoods disagree with our intuitions.</p>
<p>One particular way in which a higher likelihood can correspond to a worse model is through <strong>overfitting</strong> on the training set. Because overfitting is ubiquitous in machine learning research, the unintuitive behaviours of likelihood are often incorrectly ascribed to this phenomenon. In this post, I will assume that overfitting is not an issue, and that we are talking about properly regularised models trained on large enough datasets.</p>
<h2 id="-motivating-examples"><a name="examples"></a> Motivating examples</h2>
<h3 id="unfair-coin-flips">Unfair coin flips</h3>
<figure>
<img src="/images/coins.jpg" />
</figure>
<p><a href="https://www.jessicayung.com/counterintuitive-probabilities-typical-sets-from-information-theory/">Jessica Yung has a great blog post</a> that demonstrates how even the simplest of probability distributions start behaving in unintuitive ways in higher-dimensional spaces, and she links this to the concept of typicality. I will borrow her example here and expand on it a bit, but I recommend reading the original post.</p>
<p>To summarise: suppose you have an unfair coin that lands on heads 3 times out of 4. If you toss this coin 16 times, you would expect to see 12 heads (<code class="language-plaintext highlighter-rouge">H</code>) and 4 tails (<code class="language-plaintext highlighter-rouge">T</code>) on average. Of course you wouldn’t expect to see exactly 12 heads and 4 tails every time: there’s a pretty good chance you’d see 13 heads and 3 tails, or 11 heads and 5 tails. Seeing 16 heads and no tails would be quite surprising, but it’s not implausible: in fact, it will happen about 1% of the time. Seeing all tails seems like it would be a miracle. Nevertheless, each coin toss is independent, so even this has a non-zero probability of being observed.</p>
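<p>These numbers are easy to verify in a few lines of Python (using <code>scipy.stats.binom</code> for the distribution of the number of heads):</p>

```python
import scipy.stats

# A coin that lands on heads 3 times out of 4, tossed 16 times.
p_all_heads = 0.75**16  # probability of the single all-heads sequence
p_all_tails = 0.25**16  # probability of the single all-tails sequence

# Probability of seeing exactly 12 heads, in any order.
p_12_heads = scipy.stats.binom(16, 0.75).pmf(12)

print(p_all_heads, p_all_tails, p_12_heads)
```

<p>All heads happens about 1% of the time; all tails is roughly a one-in-four-billion event.</p>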
<p>When we count the number of heads and tails in the observed sequence, we’re looking at the <strong><a href="https://en.wikipedia.org/wiki/Binomial_distribution">binomial distribution</a></strong>. We’ve made the implicit assumption that what we care about is the <strong>frequency of occurrence of both outcomes, and not the order in which they occur</strong>. We’ve made <em>abstraction</em> of the order, and we are effectively treating the sequences as unordered sets, so that <code class="language-plaintext highlighter-rouge">HTHHTHHHHTTHHHHH</code> and <code class="language-plaintext highlighter-rouge">HHHHHTHTHHHTHTHH</code> are basically the same thing. That is often desirable, but it’s worth being aware of such assumptions, and making them explicit.</p>
<p><strong>If we do not ignore the order, and ask which sequence is the most likely, the answer is ‘all heads’.</strong> That may seem surprising at first, because seeing only heads is a relatively rare occurrence. But note that we’re asking a different question here, about the ordered sequences themselves, rather than about their statistics. While the difference is pretty clear here, the implicit assumptions and abstractions that we tend to use in our reasoning are often more subtle.</p>
<p>The table and figure below show how the probability of observing a given number of heads and tails can be found by multiplying the probability of a particular sequence with the number of such sequences. Note that ‘all heads’ has the highest probability out of all sequences (bolded), but there is only a single such sequence. The most likely number of heads we’ll observe is 12 (also bolded): even though each individual sequence with 12 heads is less likely, there are a lot more of them, and this second factor ends up dominating.</p>
<table>
<thead>
<tr>
<th style="text-align: center">#H</th>
<th style="text-align: center">#T</th>
<th style="text-align: center">p(sequence)</th>
<th style="text-align: center"># sequences</th>
<th style="text-align: center">p(#H, #T)</th>
</tr>
</thead>
<tbody>
<tr>
<td style="text-align: center">0</td>
<td style="text-align: center">16</td>
<td style="text-align: center">\(\left(\frac{3}{4}\right)^0 \left(\frac{1}{4}\right)^{16} = 2.33 \cdot 10^{-10}\)</td>
<td style="text-align: center">1</td>
<td style="text-align: center">\(2.33\cdot 10^{-10}\)</td>
</tr>
<tr>
<td style="text-align: center">1</td>
<td style="text-align: center">15</td>
<td style="text-align: center">\(\left(\frac{3}{4}\right)^1 \left(\frac{1}{4}\right)^{15} = 6.98 \cdot 10^{-10}\)</td>
<td style="text-align: center">16</td>
<td style="text-align: center">\(1.12\cdot 10^{-8}\)</td>
</tr>
<tr>
<td style="text-align: center">2</td>
<td style="text-align: center">14</td>
<td style="text-align: center">\(\left(\frac{3}{4}\right)^2 \left(\frac{1}{4}\right)^{14} = 2.10 \cdot 10^{-9}\)</td>
<td style="text-align: center">120</td>
<td style="text-align: center">\(2.51\cdot 10^{-7}\)</td>
</tr>
<tr>
<td style="text-align: center">3</td>
<td style="text-align: center">13</td>
<td style="text-align: center">\(\left(\frac{3}{4}\right)^3 \left(\frac{1}{4}\right)^{13} = 6.29 \cdot 10^{-9}\)</td>
<td style="text-align: center">560</td>
<td style="text-align: center">\(3.52\cdot 10^{-6}\)</td>
</tr>
<tr>
<td style="text-align: center">4</td>
<td style="text-align: center">12</td>
<td style="text-align: center">\(\left(\frac{3}{4}\right)^4 \left(\frac{1}{4}\right)^{12} = 1.89 \cdot 10^{-8}\)</td>
<td style="text-align: center">1820</td>
<td style="text-align: center">\(3.43\cdot 10^{-5}\)</td>
</tr>
<tr>
<td style="text-align: center">5</td>
<td style="text-align: center">11</td>
<td style="text-align: center">\(\left(\frac{3}{4}\right)^5 \left(\frac{1}{4}\right)^{11} = 5.66 \cdot 10^{-8}\)</td>
<td style="text-align: center">4368</td>
<td style="text-align: center">\(2.47\cdot 10^{-4}\)</td>
</tr>
<tr>
<td style="text-align: center">6</td>
<td style="text-align: center">10</td>
<td style="text-align: center">\(\left(\frac{3}{4}\right)^6 \left(\frac{1}{4}\right)^{10} = 1.70 \cdot 10^{-7}\)</td>
<td style="text-align: center">8008</td>
<td style="text-align: center">\(1.36\cdot 10^{-3}\)</td>
</tr>
<tr>
<td style="text-align: center">7</td>
<td style="text-align: center">9</td>
<td style="text-align: center">\(\left(\frac{3}{4}\right)^7 \left(\frac{1}{4}\right)^9 = 5.09 \cdot 10^{-7}\)</td>
<td style="text-align: center">11440</td>
<td style="text-align: center">\(5.83\cdot 10^{-3}\)</td>
</tr>
<tr>
<td style="text-align: center">8</td>
<td style="text-align: center">8</td>
<td style="text-align: center">\(\left(\frac{3}{4}\right)^8 \left(\frac{1}{4}\right)^8 = 1.53 \cdot 10^{-6}\)</td>
<td style="text-align: center">12870</td>
<td style="text-align: center">\(1.97\cdot 10^{-2}\)</td>
</tr>
<tr>
<td style="text-align: center">9</td>
<td style="text-align: center">7</td>
<td style="text-align: center">\(\left(\frac{3}{4}\right)^9 \left(\frac{1}{4}\right)^7 = 4.58 \cdot 10^{-6}\)</td>
<td style="text-align: center">11440</td>
<td style="text-align: center">\(5.24\cdot 10^{-2}\)</td>
</tr>
<tr>
<td style="text-align: center">10</td>
<td style="text-align: center">6</td>
<td style="text-align: center">\(\left(\frac{3}{4}\right)^{10} \left(\frac{1}{4}\right)^6 = 1.37 \cdot 10^{-5}\)</td>
<td style="text-align: center">8008</td>
<td style="text-align: center">\(1.10\cdot 10^{-1}\)</td>
</tr>
<tr>
<td style="text-align: center">11</td>
<td style="text-align: center">5</td>
<td style="text-align: center">\(\left(\frac{3}{4}\right)^{11} \left(\frac{1}{4}\right)^5 = 4.12 \cdot 10^{-5}\)</td>
<td style="text-align: center">4368</td>
<td style="text-align: center">\(1.80\cdot 10^{-1}\)</td>
</tr>
<tr>
<td style="text-align: center">12</td>
<td style="text-align: center">4</td>
<td style="text-align: center">\(\left(\frac{3}{4}\right)^{12} \left(\frac{1}{4}\right)^4 = 1.24 \cdot 10^{-4}\)</td>
<td style="text-align: center">1820</td>
<td style="text-align: center">\(\mathbf{2.25\cdot 10^{-1}}\)</td>
</tr>
<tr>
<td style="text-align: center">13</td>
<td style="text-align: center">3</td>
<td style="text-align: center">\(\left(\frac{3}{4}\right)^{13} \left(\frac{1}{4}\right)^3 = 3.71 \cdot 10^{-4}\)</td>
<td style="text-align: center">560</td>
<td style="text-align: center">\(2.08\cdot 10^{-1}\)</td>
</tr>
<tr>
<td style="text-align: center">14</td>
<td style="text-align: center">2</td>
<td style="text-align: center">\(\left(\frac{3}{4}\right)^{14} \left(\frac{1}{4}\right)^2 = 1.11 \cdot 10^{-3}\)</td>
<td style="text-align: center">120</td>
<td style="text-align: center">\(1.34\cdot 10^{-1}\)</td>
</tr>
<tr>
<td style="text-align: center">15</td>
<td style="text-align: center">1</td>
<td style="text-align: center">\(\left(\frac{3}{4}\right)^{15} \left(\frac{1}{4}\right)^1 = 3.33 \cdot 10^{-3}\)</td>
<td style="text-align: center">16</td>
<td style="text-align: center">\(5.35\cdot 10^{-2}\)</td>
</tr>
<tr>
<td style="text-align: center">16</td>
<td style="text-align: center">0</td>
<td style="text-align: center">\(\left(\frac{3}{4}\right)^{16} \left(\frac{1}{4}\right)^0 = \mathbf{1.00 \cdot 10^{-2}}\)</td>
<td style="text-align: center">1</td>
<td style="text-align: center">\(1.00\cdot 10^{-2}\)</td>
</tr>
</tbody>
</table>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="kn">import</span> <span class="nn">matplotlib.pyplot</span> <span class="k">as</span> <span class="n">plt</span>
<span class="kn">import</span> <span class="nn">numpy</span> <span class="k">as</span> <span class="n">np</span>
<span class="kn">import</span> <span class="nn">scipy.special</span>
<span class="n">h</span> <span class="o">=</span> <span class="n">np</span><span class="p">.</span><span class="n">arange</span><span class="p">(</span><span class="mi">16</span> <span class="o">+</span> <span class="mi">1</span><span class="p">)</span>
<span class="n">p_sequence</span> <span class="o">=</span> <span class="p">(</span><span class="mi">3</span><span class="o">/</span><span class="mi">4</span><span class="p">)</span><span class="o">**</span><span class="n">h</span> <span class="o">*</span> <span class="p">(</span><span class="mi">1</span><span class="o">/</span><span class="mi">4</span><span class="p">)</span><span class="o">**</span><span class="p">(</span><span class="mi">16</span> <span class="o">-</span> <span class="n">h</span><span class="p">)</span>
<span class="n">num_sequences</span> <span class="o">=</span> <span class="n">scipy</span><span class="p">.</span><span class="n">special</span><span class="p">.</span><span class="n">comb</span><span class="p">(</span><span class="mi">16</span><span class="p">,</span> <span class="n">h</span><span class="p">)</span>
<span class="n">p_heads_count</span> <span class="o">=</span> <span class="n">p_sequence</span> <span class="o">*</span> <span class="n">num_sequences</span>
<span class="n">plt</span><span class="p">.</span><span class="n">figure</span><span class="p">(</span><span class="n">figsize</span><span class="o">=</span><span class="p">(</span><span class="mi">9</span><span class="p">,</span> <span class="mi">3</span><span class="p">))</span>
<span class="n">plt</span><span class="p">.</span><span class="n">plot</span><span class="p">(</span><span class="n">h</span><span class="p">,</span> <span class="n">p_sequence</span><span class="p">,</span> <span class="s">'C0-s'</span><span class="p">,</span>
<span class="n">label</span><span class="o">=</span><span class="s">'probability of a single sequence with this number of heads'</span><span class="p">)</span>
<span class="n">plt</span><span class="p">.</span><span class="n">plot</span><span class="p">(</span><span class="n">h</span><span class="p">,</span> <span class="n">p_heads_count</span><span class="p">,</span> <span class="s">'C1-o'</span><span class="p">,</span>
<span class="n">label</span><span class="o">=</span><span class="s">'probability of observing this number of heads'</span><span class="p">)</span>
<span class="n">plt</span><span class="p">.</span><span class="n">yscale</span><span class="p">(</span><span class="s">'log'</span><span class="p">)</span>
<span class="n">plt</span><span class="p">.</span><span class="n">xlabel</span><span class="p">(</span><span class="s">'number of heads'</span><span class="p">)</span>
<span class="n">plt</span><span class="p">.</span><span class="n">ylabel</span><span class="p">(</span><span class="s">'probability'</span><span class="p">)</span>
<span class="n">plt</span><span class="p">.</span><span class="n">legend</span><span class="p">()</span>
</code></pre></div></div>
<figure>
<a href="/images/unfair_coin_probs.png"><img src="/images/unfair_coin_probs.png" alt="Probabilities of observing a particular sequence with a given number of heads, and of observing a given number of heads." /></a>
<figcaption>Probabilities of observing a particular sequence with a given number of heads, and of observing a given number of heads.</figcaption>
</figure>
<h3 id="gaussian-soap-bubbles">Gaussian soap bubbles</h3>
<figure>
<img src="/images/bubbles.jpg" />
</figure>
<p>Another excellent blog post about the unintuitive behaviour of high-dimensional probability distributions is Ferenc Huszar’s <a href="https://www.inference.vc/high-dimensional-gaussian-distributions-are-soap-bubble/">‘Gaussian Distributions are Soap Bubbles’</a>. A one-dimensional Gaussian looks like a bell curve: a big bump around the mode, with a tail on either side. Clearly, the bulk of the total probability mass is clumped together around the mode. In higher-dimensional spaces, this shape changes completely: the bulk of the probability mass of a spherical Gaussian distribution with unit variance in \(K\) dimensions is <strong>concentrated in a thin ‘shell’ at radius \(\sqrt{K}\)</strong>. This is known as the <em>Gaussian annulus theorem</em>.</p>
<p>For example, if we sample lots of vectors from a 100-dimensional standard Gaussian, and measure their radii, we will find that just over 84% of them are between 9 and 11, and more than 99% are between 8 and 12. Only about 0.2% have a radius smaller than 8!</p>
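<p>These percentages are easy to verify empirically with a quick simulation:</p>

```python
import numpy as np

rng = np.random.default_rng(0)
N, K = 100_000, 100
samples = rng.normal(0, 1, (N, K))
radii = np.sqrt(np.sum(samples**2, axis=-1))

frac_9_11 = np.mean((9 < radii) & (radii < 11))  # just over 84%
frac_8_12 = np.mean((8 < radii) & (radii < 12))  # more than 99%
frac_lt_8 = np.mean(radii < 8)                   # about 0.2%
```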
<p>Ferenc points out an interesting implication: <strong>high-dimensional Gaussians are very similar to uniform distributions on the sphere</strong>. This clearly isn’t true for the one-dimensional case, but it turns out that’s an exception, not the rule. Stefan Stein also discusses this implication in more detail in <a href="https://stefan-stein.github.io/posts/2020-03-07-concentration-properties-of-high-dimensional-normal-distributions/">a recent blog post</a>.</p>
<p>Where our intuition can go wrong here, is that we might underestimate how quickly a high-dimensional space grows in size as we move further away from the mode. Because of the radial symmetry of the distribution, we tend to think of all points at a given distance from the mode as similar, and we implicitly group them into sets of concentric spheres. This allows us to revert back to reasoning in one dimension, which we are more comfortable with: we think of a high-dimensional Gaussian as a distribution over these sets, rather than over individual points. What we tend to overlook, is that <strong>those sets differ wildly in size</strong>: as we move away from the mode, they grow larger very quickly. Note that this does not happen at all in 1D!</p>
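<p>We can make this tension between density and volume explicit: the density at a single point decays as \(e^{-r^2/2}\) as we move away from the mode, but the ‘size’ of the set of points at radius \(r\) grows as \(r^{K-1}\). Their product, the radial density, peaks near \(\sqrt{K}\) rather than at the mode:</p>

```python
import numpy as np

K = 100
r = np.linspace(0.01, 15, 1000)
# log density of a single point at radius r (up to an additive constant)
log_point_density = -0.5 * r**2
# log 'size' (surface area) of the sphere of radius r (up to an additive constant)
log_sphere_size = (K - 1) * np.log(r)
# their sum is the log of the radial density of the norm of a sample
log_radial_density = log_point_density + log_sphere_size

r_peak = r[np.argmax(log_radial_density)]  # close to sqrt(K - 1), i.e. ~9.95
```

<p>Note that for \(K = 1\), the \(r^{K-1}\) factor is constant, which is why this effect does not occur in one dimension.</p>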
<h2 id="-abstraction-and-the-curse-of-dimensionality"><a name="abstraction"></a> Abstraction and the curse of dimensionality</h2>
<figure>
<img src="/images/sand.jpg" />
</figure>
<p>The <a href="https://en.wikipedia.org/wiki/Curse_of_dimensionality">curse of dimensionality</a> is a catch-all term for various phenomena that appear very different and often counterintuitive in high-dimensional spaces. It is used to highlight poor scaling behaviour of ideas and algorithms, where one wouldn’t necessarily expect it. In the context of machine learning, it is usually used in a more narrow sense, to refer to the fact that models of high-dimensional data tend to require very large training datasets to be effective. But the curse of dimensionality manifests itself in many forms, and the unintuitive behaviour of high-dimensional probability distributions is just one of them.</p>
<p>In general, humans have lousy intuitions about high-dimensional spaces. But what exactly is going on when we get things wrong about high-dimensional distributions? In both of the motivating examples, the intuition breaks down in a similar way: if we’re not careful, <strong>we might implicitly reason about the probabilities of sets, rather than individual points</strong>, without taking into account their relative sizes, and arrive at the wrong answer. This means that we can encounter this issue for both discrete and continuous distributions.</p>
<p>We can generalise this idea of grouping points into sets of similar points, by thinking of it as <strong>‘abstraction’</strong>: rather than treating each point as a separate entity, we think of it as an instance of a particular <strong>concept</strong>, and ignore its idiosyncrasies. When we think of ‘sand’, we are rarely concerned about the characteristics of each individual grain. Similarly, in the ‘unfair coin flips’ example, we group sequences by their number of heads and tails, ignoring their order. In the case of the high-dimensional Gaussian, the natural grouping of points is based on their Euclidean distance from the mode. A more high-level example is that of natural images, where individual pixel values across localised regions of the image combine to form edges, textures, or even objects. There are usually many combinations of pixel values that give rise to the same texture, and we aren’t able to visually distinguish these particular instances unless we carefully study them side by side.</p>
<p>The following is perhaps a bit of an unfounded generalisation based on my own experience, but our brains seem hardwired to perform this kind of abstraction, so that we can reason about things in the familiar low-dimensional setting. It seems to happen unconsciously and continuously, and bypassing it requires a proactive approach.</p>
<h2 id="-typicality"><a name="typicality"></a> Typicality</h2>
<figure>
<img src="/images/typicality.jpg" />
</figure>
<p>Informally, <strong>typicality</strong> refers to the characteristics that samples from a distribution tend to exhibit on average (in expectation). In the ‘unfair coin flip’ example, a sequence with 12 heads and 4 tails is ‘typical’. A sequence with 6 heads and 10 tails is highly atypical. Typical sequences contain an average amount of information: they are not particularly surprising or (un)informative.</p>
<p>We can <a href="https://en.wikipedia.org/wiki/Typical_set">formalise this intuition</a> using the <a href="https://en.wikipedia.org/wiki/Entropy_(information_theory)">entropy</a> of the distribution: a <strong>typical set</strong> \(\mathcal{T}_\varepsilon \subset \mathcal{X}\) is a set of sequences from \(\mathcal{X}\) whose probability is close to \(2^{-H}\), where \(H\) is the entropy of the distribution that the sequences were drawn from, measured in bits:</p>
\[\mathcal{T}_\varepsilon = \{ \mathbf{x} \in \mathcal{X}: 2^{-(H + \varepsilon)} \leq p(\mathbf{x}) \leq 2^{-(H - \varepsilon)} \} .\]
<p>This means that the negative log likelihood of each such sequence is close to the entropy. Note that a distribution doesn’t have just one typical set: we can define many typical sets based on how close the probability of the sequences contained therein should be to \(2^{-H}\), by choosing different values of \(\varepsilon > 0\).</p>
<p>This concept was originally defined in an information-theoretic context, but I want to focus on machine learning, where I feel it is somewhat undervalued. It is often framed in terms of sequences sampled from <a href="https://en.wikipedia.org/wiki/Stationary_ergodic_process">stationary ergodic processes</a>, but it is useful more generally for distributions of any kind of high-dimensional data points, both continuous and discrete, regardless of whether we tend to think of them as sequences.</p>
<p>Why is this relevant to our discussion of abstraction and flawed human intuitions? As the dimensionality increases, the probability that any random sample from a distribution is part of a given typical set \(\mathcal{T}_\varepsilon\) tends towards 1. In other words, randomly drawn samples will almost always be ‘typical’, and <strong>the typical set covers most of the support of the distribution</strong> (this is a consequence of the so-called <a href="https://en.wikipedia.org/wiki/Asymptotic_equipartition_property">asymptotic equipartition property (AEP)</a>). This happens even when \(\varepsilon\) is relatively small, as long as the dimensionality is high enough. This is visualised for a 100-dimensional standard Gaussian distribution below (based on empirical measurements, to avoid having to calculate some <em>gnarly</em> 100D integrals).</p>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="kn">import</span> <span class="nn">matplotlib.pyplot</span> <span class="k">as</span> <span class="n">plt</span>
<span class="kn">import</span> <span class="nn">numpy</span> <span class="k">as</span> <span class="n">np</span>
<span class="n">N</span> <span class="o">=</span> <span class="mi">1000000</span>
<span class="n">K</span> <span class="o">=</span> <span class="mi">100</span>
<span class="n">samples</span> <span class="o">=</span> <span class="n">np</span><span class="p">.</span><span class="n">random</span><span class="p">.</span><span class="n">normal</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="p">(</span><span class="n">N</span><span class="p">,</span> <span class="n">K</span><span class="p">))</span>
<span class="n">radii</span> <span class="o">=</span> <span class="n">np</span><span class="p">.</span><span class="n">sqrt</span><span class="p">(</span><span class="n">np</span><span class="p">.</span><span class="nb">sum</span><span class="p">(</span><span class="n">samples</span><span class="o">**</span><span class="mi">2</span><span class="p">,</span> <span class="n">axis</span><span class="o">=-</span><span class="mi">1</span><span class="p">))</span>
<span class="n">epsilon</span> <span class="o">=</span> <span class="n">np</span><span class="p">.</span><span class="n">logspace</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="mi">2</span><span class="p">,</span> <span class="mi">200</span><span class="p">)</span>
<span class="n">lo</span> <span class="o">=</span> <span class="n">np</span><span class="p">.</span><span class="n">sqrt</span><span class="p">(</span><span class="n">np</span><span class="p">.</span><span class="n">maximum</span><span class="p">(</span><span class="n">K</span> <span class="o">-</span> <span class="n">epsilon</span> <span class="o">*</span> <span class="n">np</span><span class="p">.</span><span class="n">log</span><span class="p">(</span><span class="mi">4</span><span class="p">),</span> <span class="mi">0</span><span class="p">))</span>
<span class="n">hi</span> <span class="o">=</span> <span class="n">np</span><span class="p">.</span><span class="n">sqrt</span><span class="p">(</span><span class="n">K</span> <span class="o">+</span> <span class="n">epsilon</span> <span class="o">*</span> <span class="n">np</span><span class="p">.</span><span class="n">log</span><span class="p">(</span><span class="mi">4</span><span class="p">))</span>
<span class="n">radius_range</span> <span class="o">=</span> <span class="n">hi</span> <span class="o">-</span> <span class="n">lo</span>
<span class="n">mass</span> <span class="o">=</span> <span class="p">[</span><span class="n">np</span><span class="p">.</span><span class="n">mean</span><span class="p">((</span><span class="n">lo</span><span class="p">[</span><span class="n">i</span><span class="p">]</span> <span class="o"><</span> <span class="n">radii</span><span class="p">)</span> <span class="o">&</span> <span class="p">(</span><span class="n">radii</span> <span class="o"><</span> <span class="n">hi</span><span class="p">[</span><span class="n">i</span><span class="p">]))</span> <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="nb">len</span><span class="p">(</span><span class="n">epsilon</span><span class="p">))]</span>
<span class="n">plt</span><span class="p">.</span><span class="n">figure</span><span class="p">(</span><span class="n">figsize</span><span class="o">=</span><span class="p">(</span><span class="mi">9</span><span class="p">,</span> <span class="mi">3</span><span class="p">))</span>
<span class="n">plt</span><span class="p">.</span><span class="n">plot</span><span class="p">(</span><span class="n">radius_range</span><span class="p">,</span> <span class="n">mass</span><span class="p">)</span>
<span class="n">plt</span><span class="p">.</span><span class="n">xlabel</span><span class="p">(</span><span class="s">'Difference between the min. and max. radii inside '</span>
<span class="s">'$</span><span class="se">\\</span><span class="s">mathcal{T}_</span><span class="se">\\</span><span class="s">varepsilon$ for given $</span><span class="se">\\</span><span class="s">varepsilon$'</span><span class="p">)</span>
<span class="n">plt</span><span class="p">.</span><span class="n">ylabel</span><span class="p">(</span><span class="s">'Total probability mass in $</span><span class="se">\\</span><span class="s">mathcal{T}_</span><span class="se">\\</span><span class="s">varepsilon$'</span><span class="p">)</span>
</code></pre></div></div>
<figure>
<a href="/images/annulus_prob.png"><img src="/images/annulus_prob.png" alt="The total probability mass of a range of typical sets of a 100-dimensional standard Gaussian distribution, with their size measured by the difference between the minimal and maximal radii within the set (i.e. the width of the Gaussian annulus). An annulus with width 4 already contains most of the probability mass." /></a>
<figcaption>The total probability mass of a range of typical sets of a 100-dimensional standard Gaussian distribution, with their size measured by the difference between the minimal and maximal radii within the set (i.e. the width of the Gaussian annulus). An annulus with width 4 already contains most of the probability mass.</figcaption>
</figure>
<p>But this is where it gets interesting: for unimodal high-dimensional distributions, such as the multivariate Gaussian, <strong>the mode</strong> (i.e. the most likely value) <strong>usually isn’t part of the typical set</strong>. More generally, individual samples from high-dimensional (and potentially multimodal) distributions that have an unusually high likelihood are not typical, so we wouldn’t expect to see them when sampling. This can seem paradoxical, because they are by definition very ‘likely’ samples — it’s just that there are so few of them! Think about how surprising it would be to randomly sample the zero vector (or something very close to it) from a 100-dimensional standard Gaussian distribution.</p>
<p>This has some important implications: if we want to learn more about what a high-dimensional distribution looks like, <strong>studying the most likely samples is usually a bad idea</strong>. If we want to obtain a good quality sample from a distribution, subject to constraints, we should not be trying to find the single most likely one. Yet in machine learning, these are things that we do on a regular basis. In the next section, I’ll discuss a few situations where this paradox comes up in practice. For a more mathematical treatment of typicality and the curse of dimensionality, check out <a href="https://mc-stan.org/users/documentation/case-studies/curse-dims.html">this case study by Bob Carpenter</a>.</p>
<h2 id="-typicality-in-the-wild"><a name="in-the-wild"></a> Typicality in the wild</h2>
<figure>
<img src="/images/in_the_wild.jpg" />
</figure>
<p>A significant body of literature, spanning several subfields of machine learning, has sought to interpret and/or mitigate the unintuitive ways in which high-dimensional probability distributions behave. In this section, I want to highlight a few interesting papers and discuss them in relation to the concept of typicality. Note that I’ve made a selection based on what I’ve read recently, and this is not intended to be a comprehensive overview of the literature. In fact, I would appreciate pointers to other related work (papers and blog posts) that I should take a look at!</p>
<h3 id="language-modelling">Language modelling</h3>
<p>In conditional language modelling tasks, such as machine translation or image captioning, it is common to use conditional autoregressive models in combination with heuristic decoding strategies such as <a href="https://en.wikipedia.org/wiki/Beam_search">beam search</a>. The underlying idea is that we want to <strong>find the most likely sentence (i.e. the mode of the conditional distribution, ‘MAP decoding’)</strong>, but since this is intractable, we’ll settle for an approximate result instead.</p>
<p>With typicality in mind, it’s clear that this isn’t necessarily the best idea. Indeed, researchers have found that machine translation results, measured using the <a href="https://en.wikipedia.org/wiki/BLEU">BLEU metric</a>, sometimes get worse when the <em>beam width</em> is increased<sup id="fnref:sixchallenges" role="doc-noteref"><a href="#fn:sixchallenges" class="footnote">2</a></sup> <sup id="fnref:analyzinguncertainty" role="doc-noteref"><a href="#fn:analyzinguncertainty" class="footnote">3</a></sup>. A higher beam width gives a better, more computationally costly approximation to the mode, but not necessarily better translation results. In this case, it’s tempting to blame the metric itself, which obviously isn’t perfect, but this effect has also been observed with human ratings<sup id="fnref:tradeoff" role="doc-noteref"><a href="#fn:tradeoff" class="footnote">4</a></sup>, so that cannot be the whole story.</p>
<p>A <a href="https://arxiv.org/abs/2005.10283">recent paper by Eikema & Aziz</a><sup id="fnref:mapdecoding" role="doc-noteref"><a href="#fn:mapdecoding" class="footnote">5</a></sup> provides an excellent review of recent work in this space, and makes a compelling argument for <strong>MAP decoding as the culprit behind many of the pathologies that neural machine translation systems exhibit</strong> (rather than their network architectures or training methodologies). They also propose an alternative decoding strategy called <em>‘minimum Bayes risk’ (MBR) decoding</em> that takes into account the whole distribution, rather than only the mode.</p>
<p>In unconditional language modelling, beam search hasn’t caught on, but not for want of trying! Stochasticity of the result is often desirable in this setting, and the focus has been on sampling strategies instead. In <a href="https://arxiv.org/abs/1904.09751"><em>The Curious Case of Neural Text Degeneration</em></a><sup id="fnref:degeneration" role="doc-noteref"><a href="#fn:degeneration" class="footnote">6</a></sup>, Holtzman et al. observe that <strong>maximising the probability leads to poor quality results that are often repetitive</strong>. Repetitive samples may not be typical, but they have high likelihoods simply because they are more predictable.</p>
<p>They compare a few different sampling strategies that interpolate between fully random sampling and <em>greedy decoding</em> (i.e. predicting the most likely token at every step in the sequence), including the <em>nucleus sampling</em> technique which they propose. The motivation for trying to find a middle ground is that models will assign low probabilities to sequences that they haven’t seen much during training, which makes <strong>low-probability predictions inherently less reliable</strong>. Therefore, we want to avoid sampling low-probability tokens <em>to some extent</em>.</p>
<p><a href="https://arxiv.org/abs/2004.10450">Zhang et al.</a><sup id="fnref:tradeoff:1" role="doc-noteref"><a href="#fn:tradeoff" class="footnote">4</a></sup> frame the choice of a language model decoding strategy as a trade-off between diversity and quality. However, they find that reducing diversity only helps quality up to a point, and reducing it too much makes the results worse, as judged by human evaluators. They call this <em>‘the likelihood trap’</em>: <strong>human-judged quality of samples correlates very well with likelihood, up to an inflection point, where the correlation becomes negative</strong>.</p>
<p>In the context of typicality, this raises an interesting question: where exactly is this inflection point, and how does it relate to the typical set of the model distribution? I think it would be very interesting to determine whether the inflection point coincides exactly with the typical set, or whether it is more/less likely. Perhaps there is some degree of atypicality that human raters will tolerate? If so, can we quantify it? This wouldn’t be far-fetched: think about our preference for celebrity faces over ‘typical’ human faces, for example!</p>
<h3 id="image-modelling">Image modelling</h3>
<p>The previously mentioned <em>‘note on the evaluation of generative models’</em><sup id="fnref:anote:1" role="doc-noteref"><a href="#fn:anote" class="footnote">1</a></sup> is a seminal piece of work that demonstrates several ways in which likelihoods in the image domain can be vastly misleading.</p>
<p>In <a href="https://arxiv.org/abs/1810.09136"><em>‘Do Deep Generative Models Know What They Don’t Know?’</em></a><sup id="fnref:know" role="doc-noteref"><a href="#fn:know" class="footnote">7</a></sup>, Nalisnick et al. study the behaviour of likelihood-based models when presented with out-of-domain data. They observe how <strong>models can assign higher likelihoods to datasets other than their training datasets</strong>. Crucially, they show this for different classes of likelihood-based models (variational autoencoders, autoregressive models and flow-based models, see Figure 3 in the paper), which clearly demonstrates that this is an issue with the likelihood-based paradigm itself, and not with a particular model architecture or formulation.</p>
<p>Comparing images from CIFAR-10 and SVHN, two of the datasets they use, a key difference is the prevalence of textures in CIFAR-10 images, and the relative absence of such textures in SVHN images. This makes SVHN images inherently easier to predict, which partially explains why models trained on CIFAR-10 tend to assign higher likelihoods to SVHN images. Despite this, we clearly wouldn’t ever be able to sample anything that looks like an SVHN image from a CIFAR-10-trained model, because such images are not in the typical set of the model distribution (even if their likelihood is higher).</p>
<h3 id="audio-modelling">Audio modelling</h3>
<p>I don’t believe I’ve seen any recent work that studies sampling and decoding strategies for likelihood-based models in the audio domain. Nevertheless, I wanted to briefly discuss this setting because a question I often get is: <em>“why don’t you use greedy decoding or beam search to improve the quality of WaveNet samples?”</em></p>
<p>If you’ve read this far, the answer is probably clear to you by now: because <strong>audio samples outside of the typical set sound really weird</strong>! In fact, greedy decoding from a WaveNet will invariably yield complete silence, even for fairly strongly conditioned models (e.g. WaveNets for text-to-speech synthesis). In the text-to-speech case, even if you simply reduce the sampling temperature a bit too aggressively, certain consonants that are inherently noisy (such as ‘s’, ‘f’, ‘sh’ and ‘h’, the <a href="https://en.wikipedia.org/wiki/Fricative_consonant"><em>fricatives</em></a>) will start sounding very muffled. These sounds are effectively different kinds of noise, and reducing the stochasticity of this noise has an audible effect.</p>
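<p>To make the effect of temperature concrete, here is a minimal sketch (my own illustration, not WaveNet’s actual sampling code) of temperature-scaled sampling from a categorical output distribution: dividing the logits by a temperature \(T\) before the softmax interpolates between sampling from the model distribution (\(T = 1\)) and greedy argmax decoding (\(T \to 0\)).</p>

```python
import numpy as np

def sample_with_temperature(logits, temperature, rng):
    """Sample from a categorical distribution with temperature-scaled logits.

    temperature == 1 recovers the model distribution; temperature == 0 is
    greedy (argmax) decoding, which removes all stochasticity.
    """
    if temperature == 0.0:
        return int(np.argmax(logits))
    scaled = logits / temperature
    probs = np.exp(scaled - np.max(scaled))  # numerically stable softmax
    probs /= probs.sum()
    return int(rng.choice(len(logits), p=probs))

rng = np.random.default_rng(0)
# A nearly flat distribution, as you might see for a noise-like audio step:
logits = np.array([1.0, 1.1, 0.9])
```

<p>For noise-like sounds such as fricatives, the conditional distributions are close to flat, so collapsing them towards the argmax audibly changes the character of the signal.</p>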
<h3 id="anomaly-detection">Anomaly detection</h3>
<p>Anomaly detection, or out-of-distribution (OOD) detection, is the task of identifying whether a particular input could have been drawn from a given distribution. Generative models are often used for this purpose: train an explicit model on in-distribution data, and then use its likelihood estimates to identify OOD inputs.</p>
<p>Usually, the assumption is made that OOD inputs will have low likelihoods, and in-distribution inputs will have high likelihoods. However, the fact that the mode of a high-dimensional distribution usually isn’t part of its typical set clearly contradicts this. This mistaken assumption is quite pervasive. Only recently has it started to be challenged explicitly, e.g. in works by <a href="https://arxiv.org/abs/1906.02994">Nalisnick et al.</a><sup id="fnref:oodtypicality" role="doc-noteref"><a href="#fn:oodtypicality" class="footnote">8</a></sup> and <a href="https://arxiv.org/abs/2006.09273">Morningstar et al.</a><sup id="fnref:dose" role="doc-noteref"><a href="#fn:dose" class="footnote">9</a></sup>. Both of these works propose <strong>testing the typicality of inputs, rather than simply measuring and thresholding their likelihood</strong>.</p>
<h2 id="-the-right-level-of-abstraction"><a name="right-level"></a> The right level of abstraction</h2>
<figure>
<img src="/images/levels.jpg" />
</figure>
<p>While our intuitive notion of likelihood in high-dimensional spaces might technically be wrong, it can often be a better representation of what we actually care about. This raises the question: <strong>should we really be fitting our generative models using likelihood measured in the input space?</strong> If we were to train likelihood-based models with ‘intuitive’ likelihood, they might perform better according to perceptual metrics, because they do not have to waste capacity capturing all the idiosyncrasies of particular examples that we don’t care to distinguish anyway.</p>
<p>In fact, measuring likelihood in more abstract representation spaces has had some success in generative modelling, and I think the approach should be taken more seriously in general. In language modelling, it is common to measure likelihoods at the level of word pieces, rather than individual characters. In symbolic music modelling, recent models that operate on event-based sequences (rather than sequences with a fixed time quantum) are more effective at capturing large-scale structure<sup id="fnref:perfrnn" role="doc-noteref"><a href="#fn:perfrnn" class="footnote">10</a></sup>. Some likelihood-based generative models of images separate out or discard the least significant bits of each pixel colour value, because they are less perceptually relevant, allowing model capacity to be used more efficiently<sup id="fnref:spn" role="doc-noteref"><a href="#fn:spn" class="footnote">11</a></sup> <sup id="fnref:glow" role="doc-noteref"><a href="#fn:glow" class="footnote">12</a></sup>.</p>
<p>But perhaps the most striking example is the recent line of work where VQ-VAE<sup id="fnref:vqvae" role="doc-noteref"><a href="#fn:vqvae" class="footnote">13</a></sup> is used to <strong>learn discrete higher-level representations</strong> of perceptual signals, and generative models are then trained to maximise the likelihood in this representation space. This approach has led to models that produce images that are on par with those produced by GANs in terms of fidelity, and exceed them in terms of diversity<sup id="fnref:vqvae2" role="doc-noteref"><a href="#fn:vqvae2" class="footnote">14</a></sup> <sup id="fnref:ham" role="doc-noteref"><a href="#fn:ham" class="footnote">15</a></sup> <sup id="fnref:cas" role="doc-noteref"><a href="#fn:cas" class="footnote">16</a></sup>. It has also led to models that are able to capture long-range temporal structure in audio signals, which even GANs had not been able to do before<sup id="fnref:challenge" role="doc-noteref"><a href="#fn:challenge" class="footnote">17</a></sup> <sup id="fnref:jukebox" role="doc-noteref"><a href="#fn:jukebox" class="footnote">18</a></sup>. While the current trend in representation learning is to focus on coarse-grained representations which are suitable for discriminative downstream tasks, I think it also has a very important role to play in generative modelling.</p>
<p>In the context of modelling sets with likelihood-based models, <a href="http://akosiorek.github.io/ml/2020/08/12/machine_learning_of_sets.html#what-about-those-point-processes">a recent blog post by Adam Kosiorek</a> drew my attention to point processes, and in particular, to the formula that expresses the density over ordered sequences in terms of the density over unordered sets. This formula quantifies how we need to scale probabilities across sets of different sizes to make them comparable. I think it may yet prove useful to quantify the unintuitive behaviours of likelihood-based models.</p>
<h2 id="-closing-thoughts"><a name="closing-thoughts"></a> Closing thoughts</h2>
<figure>
<img src="/images/closing_thoughts.jpg" />
</figure>
<p>To wrap up this post, here are some takeaways:</p>
<ul>
<li>
<p><strong>High-dimensional spaces</strong>, and high-dimensional probability distributions in particular, are <strong>deeply unintuitive</strong> in more ways than one. This is a well-known fact, but they still manage to surprise us sometimes!</p>
</li>
<li>
<p>The <strong>most likely samples</strong> from a high-dimensional distribution usually aren’t a very good representation of that distribution. In most situations, we probably shouldn’t be trying to find them.</p>
</li>
<li>
<p><strong>Typicality</strong> is a very useful concept to describe these unintuitive phenomena, and I think it is <strong>undervalued in machine learning</strong> — at least in the work that I’ve been exposed to.</p>
</li>
<li>
<p>A lot of work that discusses these issues (including some that I’ve highlighted in this post) <strong>doesn’t actually refer to typicality by name</strong>. I think doing so would improve our collective understanding, and shed light on links between related phenomena in different subfields.</p>
</li>
</ul>
<p>If you have any thoughts about this topic, please don’t hesitate to share them in the comments below!</p>
<p style="background-color: #eee; padding: 1em; font-size: 120%; text-align: center; border: 1px solid #ccc; border-radius: 0.5em;">
In <a href="/2020/09/01/typicality-addendum.html">an addendum to this post</a>, I explore quantitatively what happens when our intuitions fail us in high-dimensional spaces.
</p>
<p><em>If you would like to cite this post in an academic context, you can use this BibTeX snippet:</em></p>
<div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code>@misc{dieleman2020typicality,
author = {Dieleman, Sander},
title = {Musings on typicality},
url = {https://benanne.github.io/2020/09/01/typicality.html},
year = {2020}
}
</code></pre></div></div>
<h2 id="-acknowledgements"><a name="Acknowledgements"></a> Acknowledgements</h2>
<p>Thanks to Katie Millican, Jeffrey De Fauw and Adam Kosiorek for their valuable input and feedback on this post!</p>
<h2 id="-references"><a name="references"></a> References</h2>
<div class="footnotes" role="doc-endnotes">
<ol>
<li id="fn:anote" role="doc-endnote">
<p>Theis, van den Oord and Bethge, “<a href="https://arxiv.org/abs/1511.01844">A note on the evaluation of generative models</a>”, International Conference on Learning Representations, 2016. <a href="#fnref:anote" class="reversefootnote" role="doc-backlink">↩</a> <a href="#fnref:anote:1" class="reversefootnote" role="doc-backlink">↩<sup>2</sup></a></p>
</li>
<li id="fn:sixchallenges" role="doc-endnote">
<p>Koehn & Knowles, “<a href="https://arxiv.org/abs/1706.03872">Six Challenges for Neural Machine Translation</a>”, First Workshop on Neural Machine Translation, 2017. <a href="#fnref:sixchallenges" class="reversefootnote" role="doc-backlink">↩</a></p>
</li>
<li id="fn:analyzinguncertainty" role="doc-endnote">
<p>Ott, Auli, Grangier and Ranzato, “<a href="https://arxiv.org/abs/1803.00047">Analyzing Uncertainty in Neural Machine Translation</a>”, International Conference on Machine Learning, 2018. <a href="#fnref:analyzinguncertainty" class="reversefootnote" role="doc-backlink">↩</a></p>
</li>
<li id="fn:tradeoff" role="doc-endnote">
<p>Zhang, Duckworth, Ippolito and Neelakantan, “<a href="https://arxiv.org/abs/2004.10450">Trading Off Diversity and Quality in Natural Language Generation</a>”, arXiv, 2020. <a href="#fnref:tradeoff" class="reversefootnote" role="doc-backlink">↩</a> <a href="#fnref:tradeoff:1" class="reversefootnote" role="doc-backlink">↩<sup>2</sup></a></p>
</li>
<li id="fn:mapdecoding" role="doc-endnote">
<p>Eikema and Aziz, “<a href="https://arxiv.org/abs/2005.10283">Is MAP Decoding All You Need? The Inadequacy of the Mode in Neural Machine Translation</a>”, arXiv, 2020. <a href="#fnref:mapdecoding" class="reversefootnote" role="doc-backlink">↩</a></p>
</li>
<li id="fn:degeneration" role="doc-endnote">
<p>Holtzman, Buys, Du, Forbes and Choi, “<a href="https://arxiv.org/abs/1904.09751">The Curious Case of Neural Text Degeneration</a>”, International Conference on Learning Representations, 2020. <a href="#fnref:degeneration" class="reversefootnote" role="doc-backlink">↩</a></p>
</li>
<li id="fn:know" role="doc-endnote">
<p>Nalisnick, Matsukawa, Teh, Gorur and Lakshminarayanan, “<a href="https://arxiv.org/abs/1810.09136">Do Deep Generative Models Know What They Don’t Know?</a>”, International Conference on Learning Representations, 2019. <a href="#fnref:know" class="reversefootnote" role="doc-backlink">↩</a></p>
</li>
<li id="fn:oodtypicality" role="doc-endnote">
<p>Nalisnick, Matsukawa, Teh and Lakshminarayanan, “<a href="https://arxiv.org/abs/1906.02994">Detecting Out-of-Distribution Inputs to Deep Generative Models Using Typicality</a>”, arXiv, 2019. <a href="#fnref:oodtypicality" class="reversefootnote" role="doc-backlink">↩</a></p>
</li>
<li id="fn:dose" role="doc-endnote">
<p>Morningstar, Ham, Gallagher, Lakshminarayanan, Alemi and Dillon, “<a href="https://arxiv.org/abs/2006.09273">Density of States Estimation for Out-of-Distribution Detection</a>”, arXiv, 2020. <a href="#fnref:dose" class="reversefootnote" role="doc-backlink">↩</a></p>
</li>
<li id="fn:perfrnn" role="doc-endnote">
<p>Oore, Simon, Dieleman, Eck and Simonyan, “<a href="https://arxiv.org/abs/1808.03715">This Time with Feeling: Learning Expressive Musical Performance</a>”, Neural Computing and Applications, 2020. <a href="#fnref:perfrnn" class="reversefootnote" role="doc-backlink">↩</a></p>
</li>
<li id="fn:spn" role="doc-endnote">
<p>Menick and Kalchbrenner, “<a href="https://arxiv.org/abs/1812.01608">Generating High Fidelity Images with Subscale Pixel Networks and Multidimensional Upscaling</a>”, International Conference on Machine Learning, 2019. <a href="#fnref:spn" class="reversefootnote" role="doc-backlink">↩</a></p>
</li>
<li id="fn:glow" role="doc-endnote">
<p>Kingma & Dhariwal, “<a href="https://arxiv.org/abs/1807.03039">Glow: Generative flow with invertible 1x1 convolutions</a>”, Neural Information Processing Systems, 2018. <a href="#fnref:glow" class="reversefootnote" role="doc-backlink">↩</a></p>
</li>
<li id="fn:vqvae" role="doc-endnote">
<p>van den Oord, Vinyals and Kavukcuoglu, “<a href="https://arxiv.org/abs/1711.00937">Neural Discrete Representation Learning</a>”, Neural Information Processing Systems, 2017. <a href="#fnref:vqvae" class="reversefootnote" role="doc-backlink">↩</a></p>
</li>
<li id="fn:vqvae2" role="doc-endnote">
<p>Razavi, van den Oord and Vinyals, “<a href="https://arxiv.org/abs/1906.00446">Generating Diverse High-Fidelity Images with VQ-VAE-2</a>”, Neural Information Processing Systems, 2019. <a href="#fnref:vqvae2" class="reversefootnote" role="doc-backlink">↩</a></p>
</li>
<li id="fn:ham" role="doc-endnote">
<p>De Fauw, Dieleman and Simonyan, “<a href="https://arxiv.org/abs/1903.04933">Hierarchical Autoregressive Image Models with Auxiliary Decoders</a>”, arXiv, 2019. <a href="#fnref:ham" class="reversefootnote" role="doc-backlink">↩</a></p>
</li>
<li id="fn:cas" role="doc-endnote">
<p>Ravuri and Vinyals, “<a href="https://arxiv.org/abs/1905.10887">Classification Accuracy Score for Conditional Generative Models</a>”, Neural Information Processing Systems, 2019. <a href="#fnref:cas" class="reversefootnote" role="doc-backlink">↩</a></p>
</li>
<li id="fn:challenge" role="doc-endnote">
<p>Dieleman, van den Oord and Simonyan, “<a href="https://arxiv.org/abs/1806.10474">The challenge of realistic music generation: modelling raw audio at scale</a>”, Neural Information Processing Systems, 2018. <a href="#fnref:challenge" class="reversefootnote" role="doc-backlink">↩</a></p>
</li>
<li id="fn:jukebox" role="doc-endnote">
<p>Dhariwal, Jun, Payne, Kim, Radford and Sutskever, “<a href="https://arxiv.org/abs/2005.00341">Jukebox: A Generative Model for Music</a>”, arXiv, 2020. <a href="#fnref:jukebox" class="reversefootnote" role="doc-backlink">↩</a></p>
</li>
</ol>
</div>
<p>If you’re training or sampling from generative models, typicality is a concept worth understanding. It sheds light on why beam search doesn’t work for autoregressive models of images, audio and video; why you can’t just threshold the likelihood to perform anomaly detection with generative models; and why high-dimensional Gaussians are “soap bubbles”. This post is a summary of my current thoughts on the topic.</p>
<h1 id="audio-generation">Generating music in the waveform domain</h1>
<p><em>2020-03-24, <a href="https://benanne.github.io/2020/03/24/audio-generation">https://benanne.github.io/2020/03/24/audio-generation</a></em></p>
<p>In November last year, I co-presented a tutorial on <strong>waveform-based music processing with deep learning</strong> with <a href="http://www.jordipons.me/">Jordi Pons</a> and <a href="https://jongpillee.github.io/">Jongpil Lee</a> at <a href="https://ismir2019.ewi.tudelft.nl/">ISMIR 2019</a>. Jongpil and Jordi talked about music classification and source separation respectively, and I presented the last part of the tutorial, on music generation in the waveform domain. It was very well received, so I’ve decided to write it up in the form of a blog post.</p>
<div style="float: right; width: 30%;"><a href="https://ismir2019.ewi.tudelft.nl/"><img src="/images/ismir_logo.jpg" alt="ISMIR" /></a></div>
<p>ISMIR used to be my home conference when I was a PhD student working on music information retrieval, so it was great to be back for the first time in five years. With about 450 attendees (the largest edition yet), it made for a very different experience than what I’m used to with machine learning conferences like ICML, NeurIPS and ICLR, whose audiences tend to number in the thousands these days.</p>
<p>Our tutorial on the first day of the conference gave rise to plenty of interesting questions and discussions throughout, which inspired me to write some of these things down and hopefully provide a basis to continue these discussions online. Note that I will only be covering music generation in this post, but Jordi and Jongpil are working on blog posts about their respective parts. I will share them here when they are published. In the meantime, <strong>the slide deck we used includes all three parts and is now available on <a href="https://zenodo.org/record/3529714#.XdBi0dv7Sf5">Zenodo (PDF)</a> and on <a href="https://docs.google.com/presentation/d/1_ezZXDkyhp9USAYMc5oKJCkUrUhBfo-Di8H8IfypGBM/edit#slide=id.g647f5a8648_0_57">Google slides</a></strong>. I’ve also added a few things to this post that I’ve thought of since giving the tutorial, and some new work that has come out since.</p>
<p>This is also an excellent opportunity to revive my blog, which has lain dormant for the past four years. I have taken the time to update the blog software, so if anything looks odd, that may be why. Please let me know so I can fix it!</p>
<figure>
<a href="/images/ismir_2019_photo.jpeg"><img src="/images/ismir_2019_photo.jpeg" alt="Presenting our tutorial session at ISMIR 2019 in Delft, The Netherlands." /></a>
<figcaption>Presenting our tutorial session at ISMIR 2019 in Delft, The Netherlands. Via <a href="https://twitter.com/ismir2019/status/1191341227825934336">ISMIR2019 on Twitter</a>.</figcaption>
</figure>
<h2 id="-overview"><a name="overview"></a> Overview</h2>
<p>This blog post is divided into a few different sections. I’ll try to motivate why modelling music in the waveform domain is an interesting problem. Then I’ll give an overview of generative models, the various flavours that exist, and some important ways in which they differ from each other. In the next two sections I’ll attempt to cover the state of the art in both likelihood-based and adversarial models of raw music audio. Finally, I’ll raise some observations and discussion points. If you want to skip ahead, just click the section title below to go there.</p>
<ul>
<li><em><a href="#motivation">Motivation</a></em></li>
<li><em><a href="#generative-models">Generative models</a></em></li>
<li><em><a href="#likelihood-based-models">Likelihood-based models of waveforms</a></em></li>
<li><em><a href="#adversarial-models">Adversarial models of waveforms</a></em></li>
<li><em><a href="#discussion">Discussion</a></em></li>
<li><em><a href="#conclusion">Conclusion</a></em></li>
<li><em><a href="#references">References</a></em></li>
</ul>
<p>Note that this blog post is not intended to provide an exhaustive overview of all the published research in this domain – I have tried to make a selection and I’ve inevitably left out some great work. <strong>Please don’t hesitate to suggest relevant work in the comments section!</strong></p>
<h2 id="-motivation"><a name="motivation"></a> Motivation</h2>
<h3 id="why-audio">Why audio?</h3>
<p>Music generation has traditionally been studied in the <strong>symbolic domain</strong>: the output of the generative process could be a musical score, a sequence of <a href="https://en.wikipedia.org/wiki/MIDI">MIDI events</a>, a simple melody, a sequence of chords, a textual representation<sup id="fnref:folkrnn" role="doc-noteref"><a href="#fn:folkrnn" class="footnote">1</a></sup> or some other higher-level representation. The physical process through which sound is produced is abstracted away. This dramatically reduces the amount of information that the models are required to produce, which makes the modelling problem more tractable and allows for lower-capacity models to be used effectively.</p>
<p>A very popular representation is the so-called <em>piano roll</em>, which dates back to the player pianos of the early 20th century. Holes were punched into a roll of paper to indicate which notes should be played at which time. This representation survives in digital form today and is commonly used in music production. Much of the work on music generation using machine learning has made use of (some variant of) this representation, because it allows for capturing performance-specific aspects of the music without having to model the sound.</p>
<figure class="half">
<a href="/images/player_piano.jpg"><img src="/images/player_piano.jpg" alt="Player piano with a physical piano roll inside." /></a>
<a href="/images/piano_roll.jpg"><img src="/images/piano_roll.jpg" alt="Modern incarnation of a piano roll." /></a>
<figcaption><strong>Left:</strong> player piano with a physical piano roll inside. <strong>Right:</strong> modern incarnation of a piano roll.</figcaption>
</figure>
<p>Piano rolls are great for piano performances, because they are able to exactly capture the <em>timing</em>, <em>pitch</em> and <em>velocity</em> (i.e. how hard a piano key is pressed, which is correlated with loudness, but not equivalent to it) of the notes. They are able to very accurately represent piano music, because they cover all the “degrees of freedom” that a performer has at their disposal. However, most other instruments have many more degrees of freedom: think about all the various ways you can play a note on the guitar, for example. You can decide which string to use, where to pick, whether to bend the string or not, play vibrato, … you could even play harmonics, or use two-hand tapping. Such a vast array of different playing techniques endows the performer with a lot more freedom to vary the sound that the instrument produces, and coming up with a high-level representation that can accurately capture all this variety is much more challenging. In practice, a lot of this detail is ignored and a simpler representation is often used when generating music for these instruments.</p>
<p>Modelling the sound that an instrument produces is much more difficult than modelling (some of) the parameters that are controlled by the performer, but it frees us from having to manually design high-level representations that accurately capture all these parameters. Furthermore, it allows our models to capture variability that is beyond the performer’s control: the idiosyncrasies of individual instruments, for example (no two violins sound exactly the same!), or the parameters of the recording setup used to obtain the training data for our models. It also makes it possible to model ensembles of instruments, or other sound sources altogether, without having to fundamentally change anything about the model apart from the data it is trained on.</p>
<p>Digital audio representations require a reasonably high bit rate to achieve acceptable fidelity however, and modelling all these bits comes with a cost. <strong>Music audio models necessarily require much higher capacity than their symbolic counterparts</strong>, which implies higher computational requirements for model training.</p>
<h3 id="why-waveforms"><a name="why-waveforms"></a>Why waveforms?</h3>
<p>Digital representations of sound come in many shapes and forms. For reproduction, sound is usually stored by encoding the shape of the waveform as it changes over time. For analysis however, we often make use of <strong><a href="https://en.wikipedia.org/wiki/Spectrogram">spectrograms</a></strong>, both for computational methods and for visual inspection by humans. A spectrogram can be obtained from a waveform by computing the Fourier transform of overlapping windows of the signal, and stacking the results into a 2D array. This shows the <strong>local frequency content of the signal over time</strong>.</p>
<p>Spectrograms are complex-valued: they represent both the amplitude and the phase of different frequency components at each point in time. Below is a visualisation of a magnitude spectrogram and its corresponding phase spectrogram. While the magnitude spectrogram clearly exhibits a lot of structure, with sustained frequencies manifesting as horizontal lines and harmonics showing up as parallel horizontal lines, the phase spectrogram looks a lot more random.</p>
<figure>
<a href="/images/spectrogram_magnitude.png"><img src="/images/spectrogram_magnitude.png" alt="Magnitude spectrogram of a piano recording." /></a>
<a href="/images/spectrogram_phase.png"><img src="/images/spectrogram_phase.png" alt="Phase spectrogram of a piano recording." /></a>
<figcaption><strong>Top:</strong> magnitude spectrogram of a piano recording. <strong>Bottom:</strong> the corresponding phase spectrogram.</figcaption>
</figure>
<p>When extracting information from audio signals, it turns out that we can often just <strong>discard the phase component</strong>, because it is not informative for most of the things we could be interested in. In fact, this is why the magnitude spectrogram is often referred to simply as “the spectrogram”. When generating sound however, phase is very important because it meaningfully affects our perception. Listen below to an original excerpt of a piano piece, and a corresponding excerpt where the original phase has been replaced by random uniform phase information. Note how the harmony is preserved, but the timbre changes completely.</p>
<figure class="half">
<audio controls="" src="/files/original_phase.wav"><a href="/files/original_phase.wav">Audio with original phase</a></audio>
<audio controls="" src="/files/random_phase.wav"><a href="/files/random_phase.wav">Audio with random phase</a></audio>
<figcaption><strong>Left:</strong> excerpt with original phase. <strong>Right:</strong> the same excerpt with random phase.</figcaption>
</figure>
<p>The phase component of a spectrogram is tricky to model for a number of reasons:</p>
<ul>
<li>it is an <strong>angle</strong>: \(\phi \in [0, 2 \pi)\) and it wraps around;</li>
<li>it becomes <strong>effectively random</strong> as the magnitude tends towards 0, because noise starts to dominate;</li>
<li>absolute phase is less meaningful, but <strong>relative phase differences over time matter perceptually</strong>.</li>
</ul>
<p>If we model waveforms directly, we are implicitly modelling their phase as well, but we don’t run into these issues that make modelling phase so cumbersome. There are other strategies to avoid these issues, some of which I will <a href="#alternatives">discuss later</a>, but <strong>waveform modelling currently seems to be the dominant approach in the generative setting</strong>. This is particularly interesting because magnitude spectrograms are by far the most common representation used for discriminative models of audio.</p>
<h3 id="discretising-waveforms">Discretising waveforms</h3>
<p>When representing a waveform digitally, we need to <strong>discretise it in both time and amplitude</strong>. This is referred to as <a href="https://en.wikipedia.org/wiki/Pulse-code_modulation">pulse code modulation (PCM)</a>. Because audio waveforms are effectively band-limited (humans cannot perceive frequencies above ~20 kHz), the <a href="https://en.wikipedia.org/wiki/Nyquist%E2%80%93Shannon_sampling_theorem">sampling theorem</a> tells us that we can discretise the waveform in time without any loss of information, as long as the sample rate is high enough (at least twice the highest frequency). This is why CD quality audio has a sample rate of 44.1 kHz. Much lower sample rates result in an audible loss of fidelity, but since the resulting discrete sequences also end up being much shorter, a compromise is often struck in the context of generative modelling to reduce computational requirements. Most models in the literature use sample rates of 16 or 24 kHz.</p>
<figure>
<a href="/images/digital_waveform.gif"><img style="width: 100%; border: 1px solid #eee;" src="/images/digital_waveform.gif" alt="Digital waveform." /></a>
<figcaption>Digital waveform. The individual samples become visible as the zoom level increases. Figure taken from <a href="https://deepmind.com/blog/article/wavenet-generative-model-raw-audio">the original WaveNet blog post</a>.</figcaption>
</figure>
<p>When we also quantise the amplitude, some loss of fidelity is inevitable. CD quality uses 16 bits per sample, representing 2<sup>16</sup> equally spaced quantisation levels. If we want to use fewer bits, we can use logarithmically spaced quantisation levels instead to account for our nonlinear perception of loudness. This <strong><a href="https://en.wikipedia.org/wiki/%CE%9C-law_algorithm">“mu-law companding”</a></strong> will result in a smaller perceived loss of fidelity than if the levels were equally spaced.</p>
<h2 id="-generative-models"><a name="generative-models"></a> Generative models</h2>
<p>Given a dataset \(X\) of examples \(x \in X\), which we assume to have been drawn independently from some underlying distribution \(p_X(x)\), a generative model can learn to approximate this distribution \(p_X(x)\). Such a model could be used to generate new samples that look like they could have been part of the original dataset. We distinguish <em>implicit</em> and <em>explicit</em> generative models: an implicit model can produce new samples \(x \sim p_X(x)\), but cannot be used to infer the likelihood of an example (i.e. we cannot tractably compute \(p_X(x)\) given \(x\)). If we have an explicit model, we can do this, though sometimes only up to an unknown normalising constant.</p>
<h3 id="conditional-generative-models">Conditional generative models</h3>
<p>Generative models become more practically useful when we can exert some influence over the samples we draw from them. We can do this by providing a <strong>conditioning signal</strong> \(c\), which contains side information about the kind of samples we want to generate. The model is then fit to the conditional distribution \(p_X(x \vert c)\) instead of \(p_X(x)\).</p>
<p>Conditioning signals can take many shapes or forms, and it is useful to distinguish different levels of information content. The generative modelling problem becomes easier if the conditioning signal \(c\) is richer, because it reduces uncertainty about \(x\). We will refer to conditioning signals with low information content as <em>sparse conditioning</em>, and those with high information content as <em>dense conditioning</em>. Examples of conditioning signals in the image domain and the music audio domain are shown below, ordered according to density.</p>
<figure>
<img src="/images/sparse-dense-conditioning.svg" alt="Examples of sparse and dense conditioning signals in the image domain (top) and the music audio domain (bottom)." />
<figcaption>Examples of sparse and dense conditioning signals in the image domain (top) and the music audio domain (bottom).</figcaption>
</figure>
<p>Note that the density of a conditioning signal is often correlated with its level of abstraction: high-level side information tends to be more sparse. Low-level side information isn’t necessarily dense, though. For example, we could condition a generative model of music audio on a low-dimensional vector that captures the overall timbre of an instrument. This is a low-level aspect of the audio signal, but it constitutes a sparse conditioning signal.</p>
<h3 id="likelihood-based-models">Likelihood-based models</h3>
<p>Likelihood-based models directly parameterise \(p_X(x)\). The parameters \(\theta\) are then fit by maximising the likelihood of the data under the model:</p>
\[\mathcal{L}(\theta) = \sum_{x \in X} \log p_X(x|\theta) \quad \quad \theta^* = \arg \max_\theta \mathcal{L}(\theta) .\]
<p>Note that this is typically done in the log-domain because it simplifies computations and improves numerical stability. Because the model directly parameterises \(p_X(x)\), we can <strong>easily infer the likelihood of any</strong> \(x\), so we get an explicit model. Three popular flavours of likelihood-based models are autoregressive models, flow-based models and variational autoencoders. The following three subsections provide a brief overview of each.</p>
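To make this concrete, here’s a toy sketch (entirely illustrative, not taken from any of the models discussed here): for a univariate Gaussian model with unit variance, parameterised only by its mean, a simple grid search over candidate parameters recovers the maximum-likelihood estimate, which is just the sample mean.

```python
import math

# Toy likelihood-based model: a univariate Gaussian with unit variance,
# parameterised only by its mean theta. All numbers are illustrative.
def log_likelihood(data, theta):
    # sum of log N(x | theta, 1) over the dataset, in the log-domain
    return sum(-0.5 * math.log(2 * math.pi) - 0.5 * (x - theta) ** 2 for x in data)

data = [1.8, 2.1, 2.4, 1.9, 2.3]

# Maximum likelihood via a grid search over candidate parameters.
candidates = [i / 100 for i in range(0, 401)]
theta_star = max(candidates, key=lambda t: log_likelihood(data, t))

# For a Gaussian with fixed variance, the MLE is the sample mean.
print(theta_star)  # 2.1, which equals sum(data) / len(data)
```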
<h3 id="autoregressive-models">Autoregressive models</h3>
<p>In an autoregressive model, we assume that our examples \(x \in X\) can be treated as sequences \(\{x_i\}\). We then factorise the distribution into a product of conditionals, using the <a href="https://en.wikipedia.org/wiki/Chain_rule_(probability)">chain rule of probability</a>:</p>
\[p_X(x) = \prod_i p(x_i \vert x_{<i}) .\]
<p>These conditional distributions are typically scalar-valued and much easier to model. Because we further assume that the distribution of the sequence elements is stationary, we can share parameters and use the same model for all the factors in this product.</p>
<p>For audio signals, this is a very natural thing to do, but we can also do this for other types of structured data by arbitrarily choosing an order (e.g. raster scan order for images, as in PixelRNN<sup id="fnref:pixelrnn" role="doc-noteref"><a href="#fn:pixelrnn" class="footnote">2</a></sup> and PixelCNN<sup id="fnref:pixelcnn" role="doc-noteref"><a href="#fn:pixelcnn" class="footnote">3</a></sup>).</p>
<p>Autoregressive models are attractive because they are able to <strong>accurately capture correlations between the different elements</strong> \(x_i\) in a sequence, and they allow for fast inference (i.e. computing \(p_X(x)\) given \(x\)). Unfortunately they tend to be <strong>slow to sample from</strong>, because samples need to be drawn sequentially from the conditionals for each position in the sequence.</p>
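This asymmetry between inference and sampling can be made tangible with a minimal sketch (the conditional probabilities below are invented for illustration): computing the likelihood of a given sequence is just a product of known factors, whereas sampling has to proceed one step at a time.

```python
import math, random

# Minimal first-order autoregressive model of binary sequences:
# p(x) = p(x_1) * prod_i p(x_i | x_{i-1}). All probabilities are illustrative.
p_first = {0: 0.5, 1: 0.5}
p_cond = {0: {0: 0.9, 1: 0.1},   # p(x_i | x_{i-1} = 0)
          1: {0: 0.2, 1: 0.8}}   # p(x_i | x_{i-1} = 1)

def log_prob(x):
    # Inference: every factor is known given x, so they could be evaluated in parallel.
    lp = math.log(p_first[x[0]])
    for prev, cur in zip(x, x[1:]):
        lp += math.log(p_cond[prev][cur])
    return lp

def sample(length, rng):
    # Sampling is inherently sequential: each step conditions on the previous one.
    x = [rng.choices([0, 1], weights=[p_first[0], p_first[1]])[0]]
    for _ in range(length - 1):
        w = p_cond[x[-1]]
        x.append(rng.choices([0, 1], weights=[w[0], w[1]])[0])
    return x

rng = random.Random(0)
seq = sample(8, rng)
print(seq, round(log_prob(seq), 3))
```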
<h3 id="flow-based-models">Flow-based models</h3>
<p>Another strategy for constructing a likelihood-based model is to use the <strong><a href="https://en.wikipedia.org/wiki/Probability_density_function#Function_of_random_variables_and_change_of_variables_in_the_probability_density_function">change of variables theorem</a></strong> to transform \(p_X(x)\) into a simple, factorised distribution \(p_Z(z)\) (standard Gaussian is a popular choice) using an invertible mapping \(x = g(z)\):</p>
\[p_X(x) = p_Z(z) \cdot |\det J|^{-1} \quad \quad J = \frac{dg(z)}{dz}.\]
<p>Here, \(J\) is the Jacobian of \(g(z)\). Models that use this approach are referred to as normalising flows or flow-based models<sup id="fnref:nice" role="doc-noteref"><a href="#fn:nice" class="footnote">4</a></sup><sup id="fnref:realnvp" role="doc-noteref"><a href="#fn:realnvp" class="footnote">5</a></sup>. They are fast both for inference and sampling, but the <strong>requirement for \(g(z)\) to be invertible significantly constrains the model architecture</strong>, and it makes them less parameter-efficient. In other words: flow-based models need to be quite large to be effective.</p>
<p>For an in-depth treatment of flow-based models, I recommend Eric Jang’s <a href="https://blog.evjang.com/2018/01/nf1.html">two-part blog post</a> on the subject, and <a href="https://arxiv.org/abs/1912.02762">Papamakarios et al.’s excellent review paper</a>.</p>
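The change of variables theorem is easy to verify numerically for the simplest possible “flow”: a one-dimensional affine map applied to a standard Gaussian (the values of \(a\) and \(b\) below are arbitrary illustrative choices).

```python
import math

# Change-of-variables sketch: an affine flow x = g(z) = a*z + b applied to a
# standard Gaussian base distribution p_Z. Parameters are illustrative.
a, b = 2.0, 1.0

def log_p_z(z):
    return -0.5 * math.log(2 * math.pi) - 0.5 * z ** 2

def log_p_x(x):
    # p_X(x) = p_Z(z) * |det J|^{-1}, with z = g^{-1}(x) and J = dg/dz = a.
    z = (x - b) / a
    return log_p_z(z) - math.log(abs(a))

# Sanity check against the known closed form: x is distributed as N(b, a^2).
x = 2.5
closed_form = (-0.5 * math.log(2 * math.pi * a ** 2)
               - 0.5 * (x - b) ** 2 / a ** 2)
print(abs(log_p_x(x) - closed_form) < 1e-12)  # True
```

Real flow-based models compose many such invertible transformations, each with learnable parameters, but the density computation follows exactly this pattern.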
<h3 id="variational-autoencoders-vaes">Variational autoencoders (VAEs)</h3>
<p>Variational<sup id="fnref:vaerezende" role="doc-noteref"><a href="#fn:vaerezende" class="footnote">6</a></sup> autoencoders<sup id="fnref:vaekingma" role="doc-noteref"><a href="#fn:vaekingma" class="footnote">7</a></sup> are by far the most popular class of likelihood-based generative models, so I can’t avoid mentioning them – but <strong>in the context of waveform modelling, they are probably the least popular approach</strong>. In a VAE, we jointly learn two neural networks: an <em>inference network</em> \(q(z \vert x)\) learns to probabilistically map examples \(x\) into a latent space, and a <em>generative network</em> \(p(x \vert z)\) learns the distribution of the data conditioned on a latent representation \(z\). These are trained to maximise a lower bound on \(p_X(x)\), called the ELBO (Evidence Lower BOund), because computing \(p_X(x)\) given \(x\) (exact inference) is not tractable.</p>
<p>Typical VAEs assume a factorised distribution for \(p(x \vert z)\), which limits the extent to which they can capture dependencies in the data. While this is often an acceptable trade-off, in the case of waveform modelling it turns out to be a problematic restriction in practice. I believe this is why not a lot of work has been published that takes this approach (if you know of any, please point me to it). VAEs can also have more powerful decoders with fewer assumptions (autoregressive decoders, for example), but this may introduce other issues such as posterior collapse<sup id="fnref:pc" role="doc-noteref"><a href="#fn:pc" class="footnote">8</a></sup>.</p>
<p>To learn more about VAEs, check out <a href="https://jaan.io/what-is-variational-autoencoder-vae-tutorial/">Jaan Altosaar’s tutorial</a>.</p>
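As a toy sketch of what the ELBO looks like, consider a one-dimensional latent variable model with a standard normal prior \(p(z)\), a Gaussian posterior approximation \(q(z \vert x)\), and a fixed-variance Gaussian decoder \(p(x \vert z) = \mathcal{N}(z, 1)\). All numbers below are illustrative; in a real VAE, the posterior parameters would be produced by the inference network.

```python
import math, random

# ELBO sketch for a toy 1-D latent variable model. Illustrative values only.
def kl_q_p(mu, sigma):
    # Closed-form KL( N(mu, sigma^2) || N(0, 1) )
    return 0.5 * (mu ** 2 + sigma ** 2 - 1.0 - 2.0 * math.log(sigma))

def elbo(x, mu, sigma, rng, n_samples=10000):
    # ELBO = E_q[log p(x|z)] - KL(q(z|x) || p(z)).
    # The reconstruction term is estimated by Monte Carlo sampling from q.
    recon = 0.0
    for _ in range(n_samples):
        z = mu + sigma * rng.gauss(0, 1)
        recon += -0.5 * math.log(2 * math.pi) - 0.5 * (x - z) ** 2
    return recon / n_samples - kl_q_p(mu, sigma)

rng = random.Random(0)
# The ELBO is a lower bound on log p(x); training adjusts mu and sigma
# (and the decoder) to push this bound up.
print(round(elbo(0.5, 0.4, 0.9, rng), 2))
```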
<h3 id="adversarial-models">Adversarial models</h3>
<p>Generative Adversarial Networks<sup id="fnref:gans" role="doc-noteref"><a href="#fn:gans" class="footnote">9</a></sup> (GANs) take a very different approach to capturing the data distribution. Two networks are trained simultaneously: a <em>generator</em> \(G\) attempts to produce examples according to the data distribution \(p_X(x)\), given latent vectors \(z\), while a <em>discriminator</em> \(D\) attempts to tell apart generated examples and real examples. In doing so, the discriminator provides a learning signal for the generator which enables it to better match the data distribution. In the original formulation, the loss function is as follows:</p>
\[\mathcal{L} = \mathbb{E}_x[\log D(x)] + \mathbb{E}_z[\log(1 - D(G(z)))] .\]
<p>The generator is trained to minimise this loss, whereas the discriminator attempts to maximise it. This means the training procedure is a <strong>two-player minimax game</strong>, rather than an optimisation process, as it is for most machine learning models. Balancing this game and keeping training stable has been one of the main challenges for this class of models. Many alternative formulations have been proposed to address this.</p>
<p>While adversarial and likelihood-based models are both ultimately trying to model \(p_X(x)\), they approach this target from very different angles. As a result, <strong>GANs tend to be better at producing realistic examples, but worse at capturing the full diversity of the data distribution</strong>, compared to likelihood-based models.</p>
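The minimax structure of the loss is easy to see numerically. In the sketch below (with hand-picked, purely illustrative discriminator outputs), a discriminator that separates real from generated examples achieves a higher value of the objective than one that has been fooled into outputting 0.5 everywhere:

```python
import math

# Sketch of the original GAN value function, evaluated on toy discriminator
# outputs (probabilities that an example is real). Numbers are illustrative.
def gan_value(d_real, d_fake):
    # E_x[log D(x)] + E_z[log(1 - D(G(z)))], estimated over two small batches.
    term_real = sum(math.log(p) for p in d_real) / len(d_real)
    term_fake = sum(math.log(1 - p) for p in d_fake) / len(d_fake)
    return term_real + term_fake

# A discriminating D achieves a higher value than a fooled one; D maximises
# this quantity while G tries to push it down.
confident = gan_value(d_real=[0.9, 0.8, 0.95], d_fake=[0.1, 0.2, 0.05])
fooled = gan_value(d_real=[0.5, 0.5, 0.5], d_fake=[0.5, 0.5, 0.5])
print(confident > fooled)  # True
```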
<h3 id="more-exotic-flavours">More exotic flavours</h3>
<p>Many other strategies to learn models of complicated distributions have been proposed in literature. While research on waveform generation has chiefly focused on the two dominant paradigms of likelihood-based and adversarial models, some of these alternatives may hold promise in this area as well, so I want to mention a few that I’ve come across.</p>
<ul>
<li>
<p><strong>Energy-based models</strong> measure the “energy” of examples, and are trained by fitting the model parameters so that examples coming from the dataset have low energy, whereas all other configurations of inputs have high energy. This amounts to fitting an unnormalised density. A nice recent example is <a href="https://openai.com/blog/energy-based-models/">the work by Du & Mordatch at OpenAI</a><sup id="fnref:energy" role="doc-noteref"><a href="#fn:energy" class="footnote">10</a></sup>. Energy-based models have been around for a very long time though, and one could argue that likelihood-based models are a special case.</p>
</li>
<li>
<p><strong>Optimal transport</strong> is another approach to measure the discrepancy between probability distributions, which has served as inspiration for new variants of generative adversarial networks<sup id="fnref:wgan" role="doc-noteref"><a href="#fn:wgan" class="footnote">11</a></sup> and autoencoders<sup id="fnref:swa" role="doc-noteref"><a href="#fn:swa" class="footnote">12</a></sup>.</p>
</li>
<li>
<p><strong>Autoregressive implicit quantile networks</strong><sup id="fnref:aiqn" role="doc-noteref"><a href="#fn:aiqn" class="footnote">13</a></sup> use a network architecture similar to that of likelihood-based autoregressive models, but they are trained using the quantile regression loss, rather than maximum likelihood.</p>
</li>
<li>
<p>Two continuous distributions can be matched by minimising the L2 distance between the gradients of their log-density functions with respect to the input: \(\mathcal{L} = \mathbb{E}_x [\vert\vert \nabla_x \log p_X(x) - \nabla_x \log p_Y(x) \vert\vert ^2]\). This is called <strong>score matching</strong><sup id="fnref:scorematching" role="doc-noteref"><a href="#fn:scorematching" class="footnote">14</a></sup> and some recent works have revisited this idea for density estimation<sup id="fnref:ssm" role="doc-noteref"><a href="#fn:ssm" class="footnote">15</a></sup> and generative modelling<sup id="fnref:scorebased" role="doc-noteref"><a href="#fn:scorebased" class="footnote">16</a></sup>.</p>
</li>
<li>
<p>Please share any others that I haven’t mentioned in the comments!</p>
</li>
</ul>
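The score matching idea from the list above can be illustrated with Gaussians, whose score has the simple closed form \(\nabla_x \log \mathcal{N}(x; \mu, \sigma^2) = -(x - \mu)/\sigma^2\). In this toy sketch (parameters chosen purely for illustration), the expected squared score difference vanishes when the model matches the data density and grows as it drifts away:

```python
import math, random

# Score matching sketch with 1-D Gaussians. Illustrative parameters only.
def score(x, mu, sigma):
    # d/dx log N(x; mu, sigma^2)
    return -(x - mu) / sigma ** 2

def score_matching_loss(mu_p, sigma_p, mu_q, sigma_q, rng, n=5000):
    # Monte Carlo estimate of E_x[(score_p(x) - score_q(x))^2],
    # with x drawn from the "data" density p.
    total = 0.0
    for _ in range(n):
        x = mu_p + sigma_p * rng.gauss(0, 1)
        total += (score(x, mu_p, sigma_p) - score(x, mu_q, sigma_q)) ** 2
    return total / n

rng = random.Random(0)
print(score_matching_loss(0.0, 1.0, 0.0, 1.0, rng))  # 0.0: densities match
print(score_matching_loss(0.0, 1.0, 1.0, 1.0, rng) > 0.5)  # True: they don't
```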
<h3 id="mode-covering-vs-mode-seeking-behaviour">Mode-covering vs. mode-seeking behaviour</h3>
<p>An important consideration when determining which type of generative model is appropriate for a particular application, is the degree to which it is <em>mode-covering</em> or <em>mode-seeking</em>. When a model does not have enough capacity to capture all the variability in the data, different compromises can be made. If all examples should be reasonably likely under the model, it will have to overgeneralise and put probability mass on interpolations of examples that may not be meaningful (mode-covering). If there is no such requirement, the probability mass can be focused on a subset of examples, but then some parts of the distribution will be ignored by the model (mode-seeking).</p>
<figure>
<a href="/images/mode_seeking_covering.png"><img src="/images/mode_seeking_covering.png" alt="Illustration of mode-seeking and mode-covering behaviour in model fitting." /></a>
<figcaption>Illustration of mode-seeking and mode-covering behaviour in model fitting. The blue density represents the data distribution. The green density is our model, which is a single Gaussian. Because the data distribution is multimodal, our model does not have enough capacity to accurately capture it.</figcaption>
</figure>
<p><strong>Likelihood-based models are usually mode-covering</strong>. This is a consequence of the fact that they are fit by maximising the joint likelihood of the data. <strong>Adversarial models on the other hand are typically mode-seeking</strong>. A lot of ongoing research is focused on making it possible to control the trade-off between these two behaviours directly, without necessarily having to switch the class of models that are used.</p>
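The connection between training objective and behaviour can be illustrated with a small numeric sketch: take a bimodal “data” distribution over a discrete grid and compare two unimodal candidate models, one that spreads its mass and one that commits to a single mode. The distributions below are hand-picked for illustration.

```python
import math

# Mode-covering vs. mode-seeking, numerically. Distributions are illustrative.
def kl(p, q):
    # KL divergence between two discrete distributions over the same support.
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q) if pi > 0)

p = [0.45, 0.05, 0.05, 0.45]         # bimodal data distribution
q_cover = [0.25, 0.25, 0.25, 0.25]   # spreads mass everywhere
q_seek = [0.88, 0.04, 0.04, 0.04]    # commits to a single mode

# Forward KL(p||q) (the likelihood-style objective) punishes q for missing
# mass where p has mass, so the mode-covering model wins:
print(kl(p, q_cover) < kl(p, q_seek))   # True
# Reverse KL(q||p) punishes q for putting mass where p has little,
# so the mode-seeking model wins:
print(kl(q_seek, p) < kl(q_cover, p))   # True
```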
<p>In general, mode-covering behaviour is desirable in sparsely conditioned applications, where we want diversity or we expect a certain degree of “creativity” from the model. Mode-seeking behaviour is more useful in densely-conditioned settings, where most of the variability we care about is captured in the conditioning signal, and we favour realism of the generated output over diversity.</p>
<h2 id="-likelihood-based-models-of-waveforms"><a name="likelihood-based-models"></a> Likelihood-based models of waveforms</h2>
<p>In this section, I’ll try to summarise some of the key results from the past four years obtained with likelihood-based models of waveforms. While this blog post is supposed to be about music, note that many of these developments were initially targeted at generating speech, so inevitably I will also be talking about some work in the text-to-speech (TTS) domain. I recommend reading the associated papers and/or blog posts to find out more about each of these works.</p>
<h3 id="wavenet--samplernn">WaveNet & SampleRNN</h3>
<figure>
<a href="/images/wavenet.gif"><img style="display: block; margin: auto;" src="/images/wavenet.gif" alt="Wavenet sampling procedure." /></a>
<figcaption>Animation showing sampling from a WaveNet model. The model predicts the distribution of potential signal values for each timestep, given past signal values.</figcaption>
</figure>
<p>WaveNet<sup id="fnref:wavenet" role="doc-noteref"><a href="#fn:wavenet" class="footnote">17</a></sup> and SampleRNN<sup id="fnref:samplernn" role="doc-noteref"><a href="#fn:samplernn" class="footnote">18</a></sup> are <strong>autoregressive models of raw waveforms</strong>. While WaveNet is a convolutional neural network, SampleRNN uses a stack of recurrent neural networks. Both papers appeared on arXiv in late 2016 with only a few months in between, signalling that autoregressive waveform-based audio modelling was an idea whose time had come. Before then, this idea had not been seriously considered, as modelling long-term correlations in sequences across thousands of timesteps did not seem feasible with the tools that were available at that point. Furthermore, discriminative models of audio all used spectral input representations, with only a few works investigating the use of raw waveforms in this setting (and usually with worse results).</p>
<p>Although these models have their flaws (including slow sampling due to autoregressivity, and a lack of interpretability w.r.t. what actually happens inside the network), I think they constituted an important <em>existence proof</em> that encouraged further research into waveform-based models.</p>
<p>WaveNet’s strategy to deal with long-term correlations is to use <em>dilated convolutions</em>: successive convolutional layers use filters with gaps between their inputs, so that the connectivity pattern across many layers forms a tree structure (see figure above). This enables rapid growth of the receptive field, which means that <strong>a WaveNet with only a few layers can learn dependencies across many timesteps</strong>. Note that the convolutions used in WaveNet are causal (no connectivity from future to past), which forces the model to learn to predict what values the signal could take at each position in time.</p>
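The receptive field arithmetic is worth spelling out. Assuming causal convolutions with kernel size 2 and dilations that double at every layer (as in the WaveNet paper), the receptive field is one plus the sum of the dilations, so it roughly doubles with each added layer:

```python
# Receptive field of stacked dilated causal convolutions (kernel size 2,
# dilations doubling per layer, as in WaveNet). A sketch of the arithmetic.
def receptive_field(dilations, kernel_size=2):
    return 1 + (kernel_size - 1) * sum(dilations)

dilations = [2 ** i for i in range(10)]   # 10 layers: 1, 2, 4, ..., 512
print(receptive_field(dilations))         # 1024 timesteps

# Repeating the block of dilations keeps depth logarithmic in the receptive field:
three_blocks = [2 ** i for _ in range(3) for i in range(10)]
print(receptive_field(three_blocks))      # 3070 timesteps from 30 layers
```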
<p>SampleRNN’s strategy is a bit different: multiple RNNs are stacked on top of each other, with each running at a different frequency. Higher-level RNNs update less frequently, which means they can more easily capture long-range correlations and learn high-level features.</p>
<p>Both models demonstrated excellent text-to-speech results, surpassing the state of the art at the time (concatenative synthesis, for most languages) in terms of naturalness. Both models were also applied to (piano) music generation, which constituted a nice demonstration of the promise of music generation in the waveform domain, but they were clearly limited in their ability to capture longer-term musical structure.</p>
<p style="background-color: #efe; border: 1px dashed #898; padding: 0.2em 0.5em;">
<strong>WaveNet</strong>: <a href="https://arxiv.org/abs/1609.03499">paper</a> - <a href="https://deepmind.com/blog/article/wavenet-generative-model-raw-audio">blog post</a><br />
<strong>SampleRNN</strong>: <a href="https://arxiv.org/abs/1612.07837">paper</a> - <a href="https://soundcloud.com/samplernn/sets">samples</a>
</p>
<h3 id="parallel-wavenet--clarinet">Parallel WaveNet & ClariNet</h3>
<p>Sampling from autoregressive models of raw audio can be quite slow and impractical. To address this issue, Parallel WaveNet<sup id="fnref:parallelwavenet" role="doc-noteref"><a href="#fn:parallelwavenet" class="footnote">19</a></sup> uses <em>probability density distillation</em> to train a model from which samples can be drawn in a single feed-forward pass. This requires a trained autoregressive WaveNet, which functions as a teacher, and an inverse autoregressive flow (IAF) model which acts as the student and learns to mimic the teacher’s predictions.</p>
<p>While an autoregressive model is slow to sample from, inferring the likelihood of a given example (and thus, maximum-likelihood training) can be done in parallel. <strong>For an inverse autoregressive flow, it’s the other way around: sampling is fast, but inference is slow</strong>. Since most practical applications rely on sampling rather than inference, such a model is often better suited. IAFs are hard to train from scratch though (because that requires inference), and the probability density distillation approach makes training them tractable.</p>
<p>Due to the nature of the probability density distillation objective, the student will end up matching the teacher’s predictions in a way that minimises the <em>reverse</em> KL divergence. This is quite unusual: likelihood-based models are typically trained to minimise the forward KL divergence instead, which is equivalent to maximising the likelihood (and minimising the reverse KL is usually intractable). While minimising the forward KL leads to mode-covering behaviour, <strong>minimising the reverse KL will instead lead to mode-seeking behaviour</strong>, which means that the model may end up ignoring certain modes in the data distribution.</p>
<p>In the text-to-speech (TTS) setting, this may actually be exactly what we want: given an excerpt of text, we want the model to generate a realistic utterance corresponding to that excerpt, but we aren’t particularly fussed about being able to generate every possible variation – one good-sounding utterance will do. This is a setting where <strong>realism is clearly more important than diversity</strong>, because all the diversity that we care about is already captured in the conditioning signal that we provide. This is usually the setting where adversarial models excel, because of their inherent mode-seeking behaviour, but using probability density distillation we can also train likelihood-based models this way.</p>
<p>To prevent the model from collapsing, parallel WaveNet uses a few additional loss terms to encourage the produced waveforms to resemble speech (such as a loss on the average power spectrum).</p>
<p>If we want to do music generation, we will typically care more about diversity because the conditioning signals we provide to the model are weaker. I believe this is why we haven’t really seen the Parallel WaveNet approach catch on outside of TTS.</p>
<p>ClariNet<sup id="fnref:clarinet" role="doc-noteref"><a href="#fn:clarinet" class="footnote">20</a></sup> was introduced as a variant of Parallel WaveNet which uses a Gaussian inverse autoregressive flow. The Gaussian assumption makes it possible to compute the reverse KL in closed form, rather than having to approximate it by sampling, which stabilises training.</p>
<p style="background-color: #efe; border: 1px dashed #898; padding: 0.2em 0.5em;">
<strong>Parallel WaveNet</strong>: <a href="https://arxiv.org/abs/1711.10433">paper</a> - <a href="https://deepmind.com/blog/article/high-fidelity-speech-synthesis-wavenet">blog post 1</a> - <a href="https://deepmind.com/blog/article/wavenet-launches-google-assistant">blog post 2</a><br />
<strong>ClariNet</strong>: <a href="https://arxiv.org/abs/1807.07281">paper</a> - <a href="https://clarinet-demo.github.io/">samples</a>
</p>
<h3 id="flow-based-models-waveglow-flowavenet-waveflow-blow">Flow-based models: WaveGlow, FloWaveNet, WaveFlow, Blow</h3>
<p>Training an IAF with probability density distillation isn’t the only way to train a flow-based model: most can be trained by maximum likelihood instead. In that case, the models will be encouraged to capture all the modes of the data distribution. This, in combination with their relatively low parameter efficiency (due to the invertibility requirement), means that they might need to be a bit larger to be effective. On the other hand, <strong>they allow for very fast sampling because all timesteps can be generated in parallel</strong>, so while the computational cost may be higher, sampling will still be faster in practice. Another advantage is that no additional loss terms are required to prevent collapse.</p>
<p>WaveGlow<sup id="fnref:waveglow" role="doc-noteref"><a href="#fn:waveglow" class="footnote">21</a></sup> and FloWaveNet<sup id="fnref:flowavenet" role="doc-noteref"><a href="#fn:flowavenet" class="footnote">22</a></sup>, both originally published in late 2018, are flow-based models of raw audio conditioned on mel-spectrograms, which means they can be used as <em>vocoders</em>. Because of the limited parameter efficiency of flow-based models, I suspect that it would be difficult to use them for music generation in the waveform domain, where conditioning signals are much more sparse – but they could of course be used to render mel-spectrograms generated by some other model into waveforms (more on that later).</p>
<p>WaveFlow<sup id="fnref:waveflow" role="doc-noteref"><a href="#fn:waveflow" class="footnote">23</a></sup> (with an F instead of a G) is a more recent model that improves parameter efficiency by combining the flow-based modelling approach with partial autoregressivity to model local signal structure. This allows for a trade-off between sampling speed and model size. Blow<sup id="fnref:blow" role="doc-noteref"><a href="#fn:blow" class="footnote">24</a></sup> is a flow-based model of waveforms for non-parallel voice conversion.</p>
<p style="background-color: #efe; border: 1px dashed #898; padding: 0.2em 0.5em;">
<strong>WaveGlow</strong>: <a href="https://arxiv.org/abs/1811.00002">paper</a> - <a href="https://github.com/NVIDIA/waveglow">code</a> - <a href="https://nv-adlr.github.io/WaveGlow">samples</a><br />
<strong>FloWaveNet</strong>: <a href="https://arxiv.org/abs/1811.02155">paper</a> - <a href="https://github.com/ksw0306/FloWaveNet">code</a> - <a href="https://ksw0306.github.io/flowavenet-demo/">samples</a><br />
<strong>WaveFlow</strong>: <a href="https://arxiv.org/abs/1912.01219">paper</a> - <a href="https://waveflow-demo.github.io/">samples</a><br />
<strong>Blow</strong>: <a href="https://papers.nips.cc/paper/8904-blow-a-single-scale-hyperconditioned-flow-for-non-parallel-raw-audio-voice-conversion">paper</a> - <a href="https://github.com/joansj/blow">code</a> - <a href="https://blowconversions.github.io/">samples</a>
</p>
<h3 id="hierarchical-wavenets">Hierarchical WaveNets</h3>
<p>For the purpose of music generation, <strong>WaveNet is limited by its ability to capture longer-term signal structure</strong>, as previously stated. In other words: while it is clearly able to capture local signal structure very well (i.e. the timbre of an instrument), it isn’t able to model the evolution of chord progressions and melodies over longer time periods. This makes the outputs produced by this model sound rather improvisational, to put it nicely.</p>
<p>This may seem counterintuitive at first: the tree structure of the connectivity between the layers of the model should allow for a very rapid growth of its receptive field. So if you have a WaveNet model that captures up to a second of audio at a time (more than sufficient for TTS), stacking a few more dilated convolutional layers on top should suffice to grow the receptive field by several orders of magnitude (up to many minutes). At that point, the model should be able to capture any kind of meaningful musical structure.</p>
<p>In practice, however, we need to train models on excerpts of audio that are at least as long as the longest-range correlations that we want to model. So while the depth of the model has to grow only logarithmically as we increase the desired receptive field, <strong>the computational and memory requirements for training do in fact grow linearly</strong>. If we want to train a model that can learn about musical structure across tens of seconds, that will necessarily be an order of magnitude more expensive – and WaveNets that generate music already have to be quite large as it is, even with a receptive field of just one second, because <strong>music is harder to model than speech</strong>. Note also that one second of audio corresponds to a sequence of 16000 timesteps at 16 kHz, so even at a scale of seconds, we are already modelling very long sequences.</p>
<p>In 10 years, the hardware we would need to train a WaveNet with a receptive field of 30 seconds (or almost half a million timesteps at 16 kHz) might fit in a desktop computer, so we could simply wait until then to give it a try. But if we want to train such models today, we need a different strategy. If we could train separate models to capture structure at different timescales, we could have a dedicated model that focuses on capturing longer-range correlations, without having to also model local signal structure. This seems feasible, seeing as models of high-level representations of music (i.e. scores or MIDI) clearly do a much better job of capturing long-range musical structure already.</p>
<p>We can approach this as a <strong>representation learning</strong> problem: to decouple learning of local and large-scale structure, we need to extract a more compact, high-level representation \(h\) from the audio signals \(x\), that makes abstraction of local detail and has a much lower sample rate. Ideally, we would learn a model \(h = f(x)\) to extract such a representation from data (although using existing high-level representations like MIDI is also possible, as we’ll discuss later).</p>
<p>Then we can split up the task by training two separate models: a WaveNet that models the high-level representation: \(p_H(h)\), and another that models the local signal structure, conditioned on the high-level representation: \(p_{X \vert H}(x \vert h)\). The former model can focus on learning about long-range correlations, as local signal structure is not present in the representation it operates on. The latter model, on the other hand, can focus on learning about local signal structure, as relevant information about large-scale structure is readily available in its conditioning signal. Combined together, these models can be used to sample new audio signals by first sampling \(\hat{h} \sim p_H(h)\) and then \(\hat{x} \sim p_{X \vert H}(x \vert \hat{h})\).</p>
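The two-stage sampling procedure is just ancestral sampling through the hierarchy. Here’s a toy sketch with invented stand-in distributions (a real system would use WaveNets for both stages): a slow high-level “note” sequence is drawn first, and each high-level token is then rendered as several timesteps of local detail.

```python
import random

# Ancestral sampling through a two-level hierarchy: h ~ p_H, then x ~ p_{X|H}.
# All distributions below are invented placeholders for real models.
rng = random.Random(0)

def sample_h(length):
    # High-level representation: a slow, abstract "note" sequence.
    return [rng.choice(["A", "B", "C"]) for _ in range(length)]

def sample_x_given_h(h, upsample=4):
    # Local signal structure: each high-level token is rendered as several
    # timesteps of (fake) waveform values around a note-specific offset.
    offsets = {"A": -1.0, "B": 0.0, "C": 1.0}
    return [offsets[tok] + 0.1 * rng.gauss(0, 1) for tok in h for _ in range(upsample)]

h_hat = sample_h(4)
x_hat = sample_x_given_h(h_hat)
print(h_hat, len(x_hat))  # the waveform is `upsample` times longer than h
```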
<p>We can learn both \(f(x)\) and \(p_{X \vert H}(x \vert h)\) together by training an <em>autoencoder</em>: \(f(x)\) is the encoder, a feed-forward neural network, and \(p_{X \vert H}(x \vert h)\) is the decoder, a conditional WaveNet. Learning these jointly will enable \(f(x)\) to adapt to the WaveNet, so that it extracts information that the WaveNet cannot easily model itself.</p>
<p>To make the subsequent modelling of \(h = f(x)\) with another WaveNet easier, we use a VQ-VAE<sup id="fnref:vqvae" role="doc-noteref"><a href="#fn:vqvae" class="footnote">25</a></sup>: an <strong>autoencoder with a discrete bottleneck</strong>. This has two important consequences:</p>
<ul>
<li><strong>Autoregressive models seem to be more effective on discrete sequences</strong> than on continuous ones. Making the high-level representation discrete makes the hierarchical modelling task much easier, as we don’t need to adapt the WaveNet model to work with continuous data.</li>
<li>The discreteness of the representation also <strong>limits its information capacity</strong>, forcing the autoencoder to encode only the most important information in \(h\), and to use the autoregressive connections in the WaveNet decoder to capture any local structure that wasn’t encoded in \(h\).</li>
</ul>
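The core operation in the VQ-VAE bottleneck is nearest-neighbour quantisation: each encoder output is snapped to the closest entry of a learnt codebook, so the representation can only take on a small discrete set of values. A minimal sketch, with an illustrative hand-written codebook:

```python
# Vector quantisation sketch, as in the VQ-VAE bottleneck. In a real model the
# codebook is learnt; these values are illustrative.
codebook = [[0.0, 0.0], [1.0, 1.0], [-1.0, 1.0]]

def quantise(z):
    # Index of the nearest codebook vector (squared Euclidean distance).
    dists = [sum((zi - ci) ** 2 for zi, ci in zip(z, c)) for c in codebook]
    return min(range(len(codebook)), key=dists.__getitem__)

encoder_outputs = [[0.1, -0.2], [0.9, 1.2], [-0.8, 0.7]]
codes = [quantise(z) for z in encoder_outputs]
print(codes)  # [0, 1, 2]: a discrete sequence, ready for an autoregressive prior
```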
<p>To split the task into more than two parts, we can apply this procedure again to the high-level representation \(h\) produced by the first application, and <strong>repeat this until we get a hierarchy with as many levels as desired</strong>. Higher levels in the hierarchy abstract away more and more of the low-level details of the signal, and have progressively lower sample rates (yielding shorter sequences). A three-level hierarchy is shown in the diagram below. Note that <strong>each level can be trained separately and in sequence</strong>, thus greatly reducing the computational requirements of training a model with a very large receptive field.</p>
<figure>
<img src="/images/hierarchical_wavenet.svg" alt="Hierarchical WaveNet model, consisting of (conditional) autoregressive models of several levels of learnt discrete representations." />
<figcaption>Hierarchical WaveNet model, consisting of (conditional) autoregressive models of several levels of learnt discrete representations.</figcaption>
</figure>
<p>My colleagues and I explored this idea and trained hierarchical WaveNet models on piano music<sup id="fnref:challenge" role="doc-noteref"><a href="#fn:challenge" class="footnote">26</a></sup>. We found that there was a trade-off between audio fidelity and long-range coherence of the generated samples. When more model capacity was repurposed to focus on long-range correlations, this reduced the capability of the model to capture local structure, resulting in lower perceived audio quality. We also conducted a human evaluation study where we asked several listeners to rate both the fidelity and the musicality of some generated samples, to demonstrate that hierarchical models produce samples which sound more musical.</p>
<p style="background-color: #efe; border: 1px dashed #898; padding: 0.2em 0.5em;">
<strong>Hierarchical WaveNet</strong>: <a href="https://papers.nips.cc/paper/8023-the-challenge-of-realistic-music-generation-modelling-raw-audio-at-scale">paper</a> - <a href="https://drive.google.com/drive/folders/1s7yGi928cMla8gZhfQKNXACPACSrJ9Vg">samples</a>
</p>
<h3 id="-wave2midi2wave-and-the-maestro-dataset"><a name="wave2midi2wave"></a> Wave2Midi2Wave and the MAESTRO dataset</h3>
<p>As alluded to earlier, rather than learning high-level representations of music audio from data, we could also <strong>use existing high-level representations such as MIDI</strong> to construct a hierarchical model. We can use a powerful language model to model music in the symbolic domain, and also construct a conditional WaveNet model that generates audio, given a MIDI representation. Together with my colleagues from the Magenta team at Google AI, <a href="https://magenta.tensorflow.org/maestro-wave2midi2wave">we trained such models</a> on a new dataset called MAESTRO, which features 172 hours of virtuosic piano performances, captured with fine alignment between note labels and audio waveforms<sup id="fnref:maestro" role="doc-noteref"><a href="#fn:maestro" class="footnote">27</a></sup>. This dataset is <a href="https://magenta.tensorflow.org/datasets/maestro">available to download</a> for research purposes.</p>
<p>Compared to hierarchical WaveNets with learnt intermediate representations, this approach yields much better samples in terms of musical structure, but it is limited to instruments and styles of music that MIDI can accurately represent. Manzelli et al. <a href="http://people.bu.edu/bkulis/projects/music/index.html">have demonstrated this approach</a> for a few instruments other than piano<sup id="fnref:manzellithakkar" role="doc-noteref"><a href="#fn:manzellithakkar" class="footnote">28</a></sup>, but the lack of available aligned data could pose a problem.</p>
<figure>
<img src="/images/wave2midi2wave.png" alt="Wave2Midi2Wave: a transcription model to go from audio to MIDI, a transformer to model MIDI sequences and a WaveNet to synthesise audio given a MIDI sequence." />
<figcaption>Wave2Midi2Wave: a transcription model to go from audio to MIDI, a transformer to model MIDI sequences and a WaveNet to synthesise audio given a MIDI sequence.</figcaption>
</figure>
<p style="background-color: #efe; border: 1px dashed #898; padding: 0.2em 0.5em;">
<strong>Wave2Midi2Wave</strong>: <a href="https://openreview.net/forum?id=r1lYRjC9F7">paper</a> - <a href="https://magenta.tensorflow.org/maestro-wave2midi2wave">blog post</a> - <a href="https://storage.googleapis.com/magentadata/papers/maestro/index.html">samples</a> - <a href="https://magenta.tensorflow.org/datasets/maestro">dataset</a><br />
<strong>Manzelli et al. model</strong>: <a href="https://arxiv.org/abs/1806.09905">paper</a> - <a href="http://people.bu.edu/bkulis/projects/music/index.html">samples</a>
</p>
<h3 id="sparse-transformers">Sparse transformers</h3>
<p>OpenAI introduced the <a href="https://openai.com/blog/sparse-transformer/">Sparse Transformer</a> model<sup id="fnref:sparsetransformer" role="doc-noteref"><a href="#fn:sparsetransformer" class="footnote">29</a></sup>, a large transformer<sup id="fnref:transformer" role="doc-noteref"><a href="#fn:transformer" class="footnote">30</a></sup> with a <strong>sparse attention mechanism</strong> that scales better to long sequences than traditional attention (which is quadratic in the length of the modelled sequence). They demonstrated impressive results autoregressively modelling language, images, and music audio using this architecture, with sparse attention enabling their model to cope with waveforms of up to 65k timesteps (about 5 seconds at 12 kHz). The sparse attention mechanism seems like a good alternative to the stacked dilated convolutions of WaveNets, provided that an efficient implementation is available.</p>
<p style="background-color: #efe; border: 1px dashed #898; padding: 0.2em 0.5em;">
<strong>Sparse Transformer</strong>: <a href="https://arxiv.org/abs/1904.10509">paper</a> - <a href="https://openai.com/blog/sparse-transformer/">blog post</a> - <a href="https://soundcloud.com/openai_audio/sets/sparse_transformers">samples</a>
</p>
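<p>To make the scaling benefit concrete, here is a minimal NumPy sketch of a causal attention mask in the spirit of the paper’s “strided” pattern. This is a simplification for illustration only (the actual model factorises attention across heads and uses several patterns); the function name and parameters are my own, not taken from the OpenAI code.</p>

```python
import numpy as np

def strided_sparse_mask(n, stride):
    """Simplified 'strided' sparse causal attention mask: position i may
    attend to the previous `stride` positions (local context), plus every
    position an exact multiple of `stride` steps behind it (summary
    positions). True means 'may attend'."""
    i = np.arange(n)[:, None]
    j = np.arange(n)[None, :]
    causal = j <= i                      # no attending to the future
    local = (i - j) < stride             # recent past
    strided = (i - j) % stride == 0      # one position per past block
    return causal & (local | strided)

mask = strided_sparse_mask(64, 8)
dense_cost = 64 * 65 // 2    # entries in a full causal mask
sparse_cost = int(mask.sum())  # far fewer entries actually attended to
```

<p>Per position, the cost grows roughly as <code>stride + n / stride</code> rather than <code>n</code>, which is what makes sequences of tens of thousands of timesteps feasible.</p>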
<h3 id="universal-music-translation-network">Universal music translation network</h3>
<p>An interesting conditional waveform modelling problem is that of “music translation” or “music style transfer”: given a waveform, <strong>render a new waveform where the same music is played by a different instrument</strong>. The Universal Music Translation Network<sup id="fnref:umtn" role="doc-noteref"><a href="#fn:umtn" class="footnote">31</a></sup> tackles this by training an autoencoder with multiple WaveNet decoders, where the encoded representation is encouraged to be agnostic to the instrument of the input (using an adversarial loss). A separate decoder is trained for each target instrument, so once this representation is extracted from a waveform, it can be synthesised in an instrument of choice. The separation is not perfect, but it works surprisingly well in practice. I think this is a nice example of a model that combines ideas from both likelihood-based models and the adversarial learning paradigm.</p>
<p style="background-color: #efe; border: 1px dashed #898; padding: 0.2em 0.5em;">
<strong>Universal music translation network</strong>: <a href="https://openreview.net/forum?id=HJGkisCcKm">paper</a> - <a href="https://github.com/facebookresearch/music-translation">code</a> - <a href="https://musictranslation.github.io/">samples</a>
</p>
<h3 id="dadabots">Dadabots</h3>
<p><a href="http://dadabots.com">Dadabots</a> are a researcher / artist duo who have trained SampleRNN models on various albums (primarily metal) in order to produce more music in the same vein. These models aren’t great at capturing long-range correlations, so it works best for artists whose style is naturally a bit disjointed. Below is a 24-hour livestream they’ve set up with a model generating infinite technical death metal in the style of ‘Relentless Mutation’ by Archspire.</p>
<iframe width="560" height="315" src="https://www.youtube.com/embed/MwtVkPKx3RA" frameborder="0" allow="accelerometer; autoplay; encrypted-media; gyroscope; picture-in-picture" allowfullscreen=""></iframe>
<h2 id="-adversarial-models-of-waveforms"><a name="adversarial-models"></a> Adversarial models of waveforms</h2>
<p>Adversarial modelling of audio has only recently started to see some successes, which is why this section is going to be a lot shorter than the previous one on likelihood-based models. The adversarial paradigm has been extremely successful in the image domain, but researchers have had a harder time translating that success to other domains and modalities, compared to likelihood-based models. As a result, published work so far has primarily focused on speech generation and the generation of individual notes or very short clips of music. As a field, we are still very much in the process of figuring out how to make GANs work well for audio at scale.</p>
<h3 id="wavegan">WaveGAN</h3>
<p>One of the first works to attempt using GANs for modelling raw audio signals is WaveGAN<sup id="fnref:wavegan" role="doc-noteref"><a href="#fn:wavegan" class="footnote">32</a></sup>. They trained a GAN on single-word speech recordings, bird vocalisations, individual drum hits and short excerpts of piano music. They also compared their raw audio-based model with a spectrogram-level model called SpecGAN. Although the fidelity of the <a href="https://chrisdonahue.com/wavegan_examples/">resulting samples</a> is far from perfect in some cases, this work undoubtedly inspired a lot of researchers to take audio modelling with GANs more seriously.</p>
<p style="background-color: #efe; border: 1px dashed #898; padding: 0.2em 0.5em;">
<strong>WaveGAN</strong>: <a href="https://openreview.net/forum?id=ByMVTsR5KQ">paper</a> - <a href="https://github.com/chrisdonahue/wavegan">code</a> - <a href="https://chrisdonahue.com/wavegan_examples/">samples</a> - <a href="https://chrisdonahue.com/wavegan/">demo</a> - <a href="https://colab.research.google.com/drive/1e9o2NB2GDDjadptGr3rwQwTcw-IrFOnm">colab</a>
</p>
<h3 id="gansynth">GANSynth</h3>
<p>So far in this blog post, we have focused on generating audio waveforms directly. However, I don’t want to omit GANSynth<sup id="fnref:gansynth" role="doc-noteref"><a href="#fn:gansynth" class="footnote">33</a></sup>, even though technically speaking it does not operate directly in the waveform domain. This is because the spectral representation it uses is <strong>exactly invertible</strong> – no other models or phase reconstruction algorithms are used to turn the spectrograms it generates into waveforms, which means it shares a lot of the advantages of models that operate directly in the waveform domain.</p>
<p>As <a href="#why-waveforms">discussed before</a>, modelling the phase component of a complex spectrogram is challenging, because the phase of real audio signals can seem essentially random. However, using some of its unique characteristics, we can transform the phase into a quantity that is easier to model and reason about: the <em>instantaneous frequency</em>. This is obtained by computing the temporal difference of the <em>unwrapped</em> phase between subsequent frames. “Unwrapping” means that we shift the phase component by a multiple of \(2 \pi\) for each frame as needed to make it monotonic over time, as shown in the diagram below (because phase is an angle, all values modulo \(2 \pi\) are equivalent).</p>
<p><strong>The instantaneous frequency captures how much the phase of a signal moves from one spectrogram frame to the next</strong>. For harmonic sounds, this quantity is expected to be constant over time, as the phase rotates at a constant velocity. This makes the representation particularly suitable for modelling musical sounds, which have a lot of harmonic content (and in fact, it might make it less suitable for modelling more general classes of audio signals, though I don’t know if anyone has tried): for harmonic sounds, the instantaneous frequency is almost trivial to predict.</p>
<p>GANSynth is an adversarial model trained to produce the magnitude and instantaneous frequency spectrograms of recordings of individual musical notes. The trained model is also able to generalise to sequences of notes to some degree. <a href="https://magenta.tensorflow.org/gansynth">Check out the blog post</a> for sound examples and more information.</p>
<figure>
<img src="/images/gansynth1.png" alt="Waveform with spectrogram frame boundaries indicated as dotted lines." />
<img src="/images/gansynth2.png" alt="From phase to instantaneous frequency." />
<img src="/images/gansynth3.png" alt="Visualisations of the magnitude, phase, unwrapped phase and instantaneous frequency spectra of a real recording of a note." />
<figcaption><strong>Top</strong>: waveform with spectrogram frame boundaries indicated as dotted lines. <strong>Middle</strong>: from phase to instantaneous frequency. <strong>Bottom</strong>: visualisations of the magnitude, phase, unwrapped phase and instantaneous frequency spectra of a real recording of a note.</figcaption>
</figure>
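<p>The unwrap-and-difference procedure described above can be sketched in a few lines of NumPy. The function below is illustrative (its name and the scaling to the range (-1, 1] are my own conventions, not taken from the GANSynth code):</p>

```python
import numpy as np

def instantaneous_frequency(phase, time_axis=-1):
    """Unwrap the phase along time (removing 2*pi discontinuities), then
    take the temporal difference between subsequent frames, scaled so the
    result lies in (-1, 1]."""
    unwrapped = np.unwrap(phase, axis=time_axis)
    dphase = np.diff(unwrapped, axis=time_axis)
    return dphase / np.pi

# A bin whose phase rotates at a constant 2.5 rad/frame (wrapped to
# (-pi, pi]) should yield a constant instantaneous frequency:
frames = np.arange(20)
phase = np.angle(np.exp(1j * 2.5 * frames))
```

<p>For a harmonic sound, each frequency bin behaves roughly like this synthetic example: constant rotation, hence constant instantaneous frequency.</p>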
<p style="background-color: #efe; border: 1px dashed #898; padding: 0.2em 0.5em;">
<strong>GANSynth</strong>: <a href="https://openreview.net/forum?id=H1xQVn09FX">paper</a> - <a href="http://goo.gl/magenta/gansynth-code">code</a> - <a href="http://goo.gl/magenta/gansynth-examples">samples</a> - <a href="https://magenta.tensorflow.org/gansynth">blog post</a> - <a href="http://goo.gl/magenta/gansynth-demo">colab</a>
</p>
<h3 id="-melgan--gan-tts"><a name="melgan-gantts"></a> MelGAN & GAN-TTS</h3>
<p>Two recent papers demonstrate excellent results using GANs for text-to-speech: MelGAN<sup id="fnref:melgan" role="doc-noteref"><a href="#fn:melgan" class="footnote">34</a></sup> and GAN-TTS<sup id="fnref:gantts" role="doc-noteref"><a href="#fn:gantts" class="footnote">35</a></sup>. The former also includes some music synthesis results, although fidelity is still an issue in that domain. The focus of MelGAN is inversion of magnitude spectrograms (potentially generated by other models), whereas GAN-TTS is conditioned on the same “linguistic features” as the original WaveNet for TTS.</p>
<p>The architectures of both models share some interesting similarities, which shed light on the right inductive biases for raw waveform discriminators. Both models use <strong>multiple discriminators at different scales</strong>, each of which operates on a <strong>random window</strong> of audio extracted from the full sequence produced by the generator. This is similar to the patch-based discriminators that have occasionally been used in GANs for image generation. This windowing strategy seems to dramatically improve the capability of the generator to <strong>correctly model high-frequency content</strong> in the audio signals, which is much more crucial to get right for audio than for images because it more strongly affects perceptual quality. The fact that both models benefited from this particular discriminator design indicates that we may be on the way to figuring out how to best design discriminator architectures for raw audio.</p>
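<p>A minimal sketch of the windowing strategy, so each discriminator scale only ever sees a random crop of the waveform. The window sizes here are hypothetical, and the real models additionally downsample the audio for the coarser scales:</p>

```python
import numpy as np

def discriminator_windows(audio, window_sizes=(2048, 4096, 8192), rng=None):
    """Extract one random window per discriminator scale from a generated
    (or real) waveform; each discriminator judges only its own crop."""
    rng = np.random.default_rng(rng)
    windows = []
    for size in window_sizes:
        start = rng.integers(0, len(audio) - size + 1)
        windows.append(audio[start:start + size])
    return windows

audio = np.random.default_rng(0).standard_normal(16000)
crops = discriminator_windows(audio, rng=0)
```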
<p>There are also some interesting differences: where GAN-TTS uses a combination of conditional and unconditional discriminators, MelGAN uses only unconditional discriminators and instead encourages the generator output to match the ground truth audio by adding an additional <em>feature matching</em> loss: the L1 distance between discriminator feature maps of real and generated audio. Both approaches seem to be effective.</p>
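<p>The feature matching loss mentioned above amounts to an L1 distance between discriminator activations on real and generated audio. A framework-agnostic sketch, using NumPy arrays to stand in for the feature maps of a few discriminator layers:</p>

```python
import numpy as np

def feature_matching_loss(real_feats, fake_feats):
    """Mean L1 distance between corresponding discriminator feature maps
    of real and generated audio, averaged over layers."""
    assert len(real_feats) == len(fake_feats)
    per_layer = [np.mean(np.abs(r - f)) for r, f in zip(real_feats, fake_feats)]
    return float(np.mean(per_layer))

# Toy feature maps from three discriminator layers; the "generated"
# features are the real ones offset by 0.1 everywhere.
rng = np.random.default_rng(0)
real = [rng.standard_normal((16, 128)) for _ in range(3)]
fake = [r + 0.1 for r in real]
loss = feature_matching_loss(real, fake)  # -> 0.1 (up to rounding)
```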
<p>Adversarial waveform synthesis is particularly useful for TTS, because it enables the use of highly parallelisable feed-forward models, which tend to have relatively low capacity requirements because they are trained with a mode-seeking loss. This means the models <strong>can more easily be deployed on low-power hardware while still performing audio synthesis in real-time</strong>, compared to autoregressive or flow-based models.</p>
<p style="background-color: #efe; border: 1px dashed #898; padding: 0.2em 0.5em;">
<strong>MelGAN</strong>: <a href="https://papers.nips.cc/paper/9629-melgan-generative-adversarial-networks-for-conditional-waveform-synthesis">paper</a> - <a href="https://github.com/descriptinc/melgan-neurips">code</a> - <a href="https://melgan-neurips.github.io/">samples</a><br />
<strong>GAN-TTS</strong>: <a href="https://openreview.net/forum?id=r1gfQgSFDr">paper</a> - <a href="https://github.com/mbinkowski/DeepSpeechDistances">code (FDSD)</a> - <a href="https://storage.googleapis.com/deepmind-media/research/abstract.wav">sample</a>
</p>
<h2 id="-discussion"><a name="discussion"></a> Discussion</h2>
<p>To wrap up this blog post, I want to summarise a few thoughts about the current state of this area of research, and where things could be moving next.</p>
<h3 id="why-the-emphasis-on-likelihood-in-music-modelling">Why the emphasis on likelihood in music modelling?</h3>
<p>Clearly, the dominant paradigm for generative models of music in the waveform domain is likelihood-based. This stands in stark contrast to the image domain, where adversarial approaches greatly outnumber likelihood-based ones. I suspect there are a few reasons for this (let me know if you think of any others):</p>
<ul>
<li>
<p>Compared to likelihood-based models, it seems like it has been harder to translate the successes of adversarial models in the image domain to other domains, and to the audio domain in particular. I think this is because in a GAN, the discriminator fulfills the role of a <strong>domain-specific loss function</strong>, and important prior knowledge that guides learning is encoded in its architecture. We have known about good architectural priors for images for a long time (stacks of convolutions), as evidenced by work on e.g. style transfer<sup id="fnref:styletransfer" role="doc-noteref"><a href="#fn:styletransfer" class="footnote">36</a></sup> and the deep image prior<sup id="fnref:deepimageprior" role="doc-noteref"><a href="#fn:deepimageprior" class="footnote">37</a></sup>. For other modalities, we don’t know as much yet. It seems we are now starting to figure out what kind of architectures work for waveforms (see <a href="#melgan-gantts">MelGAN and GAN-TTS</a>, some relevant work has also been done in the discriminative setting<sup id="fnref:randomcnn" role="doc-noteref"><a href="#fn:randomcnn" class="footnote">38</a></sup>).</p>
</li>
<li>
<p><strong>Adversarial losses are mode-seeking</strong>, which makes them more suitable for settings where realism is more important than diversity (for example, because the conditioning signal contains most of the required diversity, as in TTS). In music generation, which is primarily a creative application, <strong>diversity is very important</strong>. Improving diversity of GAN samples is the subject of intense study right now, but I think it could be a while before they catch up with likelihood-based models in this sense.</p>
</li>
<li>
<p>The current disparity could also simply be a consequence of the fact that <strong>likelihood-based models got a head start</strong> in waveform modelling, with WaveNet and SampleRNN appearing on the scene in 2016 and WaveGAN in 2018.</p>
</li>
</ul>
<p>Another domain where likelihood-based models dominate is language modelling. I believe the underlying reasons for this might be a bit different though: language is inherently <strong>discrete</strong>, and extending GANs to modelling discrete data at scale is very much a work in progress. This is likely also why likelihood-based models dominate symbolic music generation: most symbolic representations of music are discrete.</p>
<h3 id="-alternatives-to-modelling-waveforms-directly"><a name="alternatives"></a> Alternatives to modelling waveforms directly</h3>
<p>Instead of modelling music in the waveform domain, there are many possible alternative approaches. We could model other representations of audio signals, such as spectrograms, as long as we have a way to obtain waveforms from such representations. We have quite a few options for this:</p>
<ul>
<li>
<p>We could use <strong>invertible spectrograms</strong> (i.e. phase information is not discarded), but in this case modelling the phase poses a considerable challenge. There are ways to make this easier, such as the instantaneous frequency representation used by GANSynth.</p>
</li>
<li>
<p>We could also use <strong>magnitude spectrograms</strong> (as is typically done in discriminative models of audio), and then use a <strong>phase reconstruction algorithm</strong> such as the Griffin-Lim algorithm<sup id="fnref:griffinlim" role="doc-noteref"><a href="#fn:griffinlim" class="footnote">39</a></sup> to infer a plausible phase component, based only on the generated magnitude. This approach was used for the original Tacotron model for TTS<sup id="fnref:tacotron" role="doc-noteref"><a href="#fn:tacotron" class="footnote">40</a></sup>, and for MelNet<sup id="fnref:melnet" role="doc-noteref"><a href="#fn:melnet" class="footnote">41</a></sup>, which models music audio autoregressively in the spectrogram domain.</p>
</li>
<li>
<p>Instead of a traditional phase reconstruction algorithm, we could also use a <strong>vocoder</strong> to go from spectrograms to waveforms. A vocoder, in this context, is simply a generative model in the waveform domain, conditioned on spectrograms. Vocoding is a densely conditioned generation task, and many of the models discussed before can and have been used as vocoders (e.g. WaveNet in Tacotron 2<sup id="fnref:tacotron2" role="doc-noteref"><a href="#fn:tacotron2" class="footnote">42</a></sup>, flow-based models of waveforms, or MelGAN). This approach has some advantages: generated magnitude spectrograms are often imperfect, and vocoder models can learn to account for these imperfections. Vocoders can also work with inherently lossy spectrogram representations such as mel-spectrograms and constant-Q spectrograms<sup id="fnref:constantq" role="doc-noteref"><a href="#fn:constantq" class="footnote">43</a></sup>.</p>
</li>
<li>
<p>If we are generating audio conditioned on an existing audio signal, we could also simply <strong>reuse the phase</strong> of the input signal, rather than reconstructing or generating it. This is commonly done in source separation, and the approach could also be used for music style transfer.</p>
</li>
</ul>
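<p>To illustrate the phase reconstruction option, here is a minimal Griffin-Lim implementation built on SciPy’s STFT. The parameters are illustrative and the loop is the textbook version; library implementations (e.g. <code>librosa.griffinlim</code>) are more robust and should be preferred in practice.</p>

```python
import numpy as np
from scipy.signal import stft, istft

def griffin_lim(mag, n_iter=32, nperseg=1024, noverlap=768, seed=0):
    """Iteratively refine a random phase estimate so that the complex
    spectrogram mag * phase is consistent with some waveform."""
    rng = np.random.default_rng(seed)
    phase = np.exp(2j * np.pi * rng.random(mag.shape))
    for _ in range(n_iter):
        # Back to the time domain with the current phase estimate...
        _, x = istft(mag * phase, nperseg=nperseg, noverlap=noverlap)
        # ...then re-analyse, keeping only the phase of the result.
        _, _, spec = stft(x, nperseg=nperseg, noverlap=noverlap)
        phase = np.exp(1j * np.angle(spec))
    _, x = istft(mag * phase, nperseg=nperseg, noverlap=noverlap)
    return x

# Usage: reconstruct a 440 Hz tone from its magnitude spectrogram only.
fs = 16000
tone = np.sin(2 * np.pi * 440.0 * np.arange(fs) / fs)
_, _, spec = stft(tone, nperseg=1024, noverlap=768)
reconstructed = griffin_lim(np.abs(spec), n_iter=16)
```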
<p>That said, modelling spectrograms <strong>isn’t always easier</strong> than modelling waveforms. Although spectrograms have a much lower temporal resolution, they contain much more information per timestep. In autoregressive models of spectrograms, one would have to condition along both the time and frequency axes to capture all dependencies, which means we end up with roughly as many sequential sampling steps as in the raw waveform case. This is the approach taken by MelNet.</p>
<p>An alternative is to make an <strong>assumption of independence between different frequency bands at each timestep</strong>, given previous timesteps. This enables autoregressive models to produce entire spectrogram frames at a time. This partial independence assumption turns out to be an acceptable compromise in the text-to-speech domain, and is used in Tacotron and Tacotron 2. Vocoder models are particularly useful here as they can attempt to fix the imperfections resulting from this simplification of the model. I’m not sure if anybody has tried, but I would suspect that this independence assumption would cause more problems for music generation.</p>
<p>An interesting new approach combining traditional signal processing ideas with neural networks is <a href="https://magenta.tensorflow.org/ddsp">Differentiable Digital Signal Processing (DDSP)</a><sup id="fnref:ddsp" role="doc-noteref"><a href="#fn:ddsp" class="footnote">44</a></sup>. By creating learnable versions of existing DSP components and incorporating them directly into neural networks, these models are endowed with <strong>much stronger inductive biases about sound and music</strong>, and can learn to produce realistic audio with fewer trainable parameters, while also being more interpretable. I suspect that this research direction may gain a lot of traction in the near future, not in the least because the authors <a href="https://github.com/magenta/ddsp">have made their code publicly available</a>, and also because of its modularity and lower computational requirements.</p>
<figure>
<img src="/images/ddsp.png" alt="Diagram of an example DDSP model. The yellow boxes represent differentiable signal processing components." />
<figcaption>Diagram of an example DDSP model. The yellow boxes represent differentiable signal processing components. Taken from <a href="https://magenta.tensorflow.org/ddsp">the original blog post</a>.</figcaption>
</figure>
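<p>One of the core DDSP components is a differentiable harmonic oscillator. Here is a heavily simplified, non-differentiable NumPy sketch of the underlying idea; the real models learn time-varying, frame-level controls, upsample them to audio rate, and add a filtered noise component, none of which is shown here.</p>

```python
import numpy as np

def harmonic_synth(f0, harmonic_amps, sr=16000):
    """Additive synthesis: sum sinusoids at integer multiples of a
    per-sample fundamental frequency f0 (in Hz), weighted by
    harmonic_amps. Phase is the cumulative sum of the instantaneous
    angular frequency."""
    phase = 2 * np.pi * np.cumsum(f0) / sr
    audio = np.zeros_like(phase)
    for k, amp in enumerate(harmonic_amps, start=1):
        audio += amp * np.sin(k * phase)
    return audio

# One second of a 220 Hz tone with three decaying harmonics:
f0 = np.full(16000, 220.0)
audio = harmonic_synth(f0, harmonic_amps=[1.0, 0.5, 0.25])
```

<p>Because the synthesiser itself guarantees a harmonic signal, the network only has to predict a handful of control signals rather than every audio sample, which is where the parameter savings come from.</p>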
<p>Finally, we could train <strong>symbolic models of music</strong> instead: for many instruments, we already have realistic synthesisers, and we can even train them given enough data (see <a href="#wave2midi2wave">Wave2Midi2Wave</a>). If we are able to craft symbolic representations that capture the aspects of music we care about, then this is an attractive approach as it is much less computationally intensive. Magenta’s <a href="https://magenta.tensorflow.org/music-transformer">Music Transformer</a><sup id="fnref:musictransformer" role="doc-noteref"><a href="#fn:musictransformer" class="footnote">45</a></sup> and OpenAI’s <a href="https://openai.com/blog/musenet/">MuseNet</a> are two models that have recently shown impressive results in this domain, and it is likely that other ideas from the language modelling community could bring further improvements.</p>
<p style="background-color: #efe; border: 1px dashed #898; padding: 0.2em 0.5em;">
<strong>DDSP</strong>: <a href="https://openreview.net/forum?id=B1x1ma4tDr">paper</a> - <a href="https://github.com/magenta/ddsp">code</a> - <a href="https://g.co/magenta/ddsp-examples">samples</a> - <a href="https://magenta.tensorflow.org/ddsp">blog post</a> - <a href="https://g.co/magenta/ddsp-demo">colab</a><br />
<strong>Music Transformer</strong>: <a href="https://openreview.net/forum?id=rJe4ShAcF7">paper</a> - <a href="https://magenta.tensorflow.org/music-transformer">blog post</a><br />
<strong>MuseNet</strong>: <a href="https://openai.com/blog/musenet/">blog post</a>
</p>
<h3 id="whats-next">What’s next?</h3>
<p>Generative models of music in the waveform domain have seen substantial progress over the past few years, but the best results so far are still relatively easy to distinguish from real recordings, even at fairly short time scales. There is still a lot of room for improvement, but I believe a lot of this will be driven by better availability of computational resources, and not necessarily by radical innovation on the modelling front – we have great tools already, they are simply a bit expensive to use due to <strong>substantial computational requirements</strong>. As time goes on and computers get faster, hopefully this task will garner interest as it becomes accessible to more researchers.</p>
<p>One interesting question is <strong>whether adversarial models are going to catch up</strong> with likelihood-based models in this domain. I think it is quite likely that GANs, having recently made in-roads in the densely conditioned setting, will gradually be made to work for more sparsely conditioned audio generation tasks as well. Fully unconditional generation with long-term coherence seems very challenging however, and I suspect that the mode-seeking behaviour of the adversarial loss will make this much harder to achieve. A hybrid model, where a GAN captures local signal structure and another model with a different objective function captures high-level structure and long-term correlations, seems like a sensible thing to build.</p>
<p><strong>Hierarchy</strong> is a very important prior for music (and, come to think of it, for pretty much anything else we like to model), so models that explicitly incorporate this are going to have a leg up on models that don’t – at the cost of some additional complexity. Whether this additional complexity will always be worth it remains to be seen, but at the moment, this definitely seems to be the case.</p>
<p>At any rate, <strong>splitting up the problem into multiple stages</strong> that can be solved separately has been fruitful, and I think it will continue to be. So far, hierarchical models (with learnt or handcrafted intermediate representations) and spectrogram-based models with vocoders have worked well, but perhaps there are other ways to “divide and conquer”. A nice example of a different kind of split in the image domain is the one used in Subscale Pixel Networks<sup id="fnref:spn" role="doc-noteref"><a href="#fn:spn" class="footnote">46</a></sup>, where separate networks model the most and least significant bits of the image data.</p>
<h2 id="-conclusion"><a name="conclusion"></a> Conclusion</h2>
<p>If you made it to the end of this post, congratulations! I hope I’ve convinced you that music modelling in the waveform domain is an interesting research problem. It is also <strong>very far from a solved problem</strong>, so there are lots of opportunities for interesting new work. I have probably missed a lot of relevant references, especially when it comes to more recent work. If you know about relevant work that isn’t discussed here, feel free to share it in the comments! Questions about this blog post and this line of research are very welcome as well.</p>
<h2 id="-references"><a name="references"></a> References</h2>
<div class="footnotes" role="doc-endnotes">
<ol>
<li id="fn:folkrnn" role="doc-endnote">
<p>Sturm, Santos, Ben-Tal and Korshunova, “<a href="https://arxiv.org/pdf/1604.08723">Music transcription modelling and composition using deep learning</a>”, Proc. 1st Conf. Computer Simulation of Musical Creativity, Huddersfield, UK, July 2016. <a href="https://folkrnn.org/">folkrnn.org</a> <a href="#fnref:folkrnn" class="reversefootnote" role="doc-backlink">↩</a></p>
</li>
<li id="fn:pixelrnn" role="doc-endnote">
<p>Van den Oord, Kalchbrenner and Kavukcuoglu, “<a href="https://arxiv.org/abs/1601.06759">Pixel recurrent neural networks</a>”, International Conference on Machine Learning, 2016. <a href="#fnref:pixelrnn" class="reversefootnote" role="doc-backlink">↩</a></p>
</li>
<li id="fn:pixelcnn" role="doc-endnote">
<p>Van den Oord, Kalchbrenner, Espeholt, Vinyals and Graves, “<a href="http://papers.nips.cc/paper/6527-conditional-image-generation-with-pixelcnn-decoders">Conditional image generation with pixelcnn decoders</a>”, Advances in neural information processing systems 29 (NeurIPS), 2016. <a href="#fnref:pixelcnn" class="reversefootnote" role="doc-backlink">↩</a></p>
</li>
<li id="fn:nice" role="doc-endnote">
<p>Dinh, Krueger and Bengio, “<a href="https://arxiv.org/abs/1410.8516">NICE: Non-linear Independent Components Estimation</a>”, arXiv, 2014. <a href="#fnref:nice" class="reversefootnote" role="doc-backlink">↩</a></p>
</li>
<li id="fn:realnvp" role="doc-endnote">
<p>Dinh, Sohl-Dickstein and Bengio, “<a href="https://arxiv.org/abs/1605.08803">Density estimation using Real NVP</a>”, arXiv, 2016. <a href="#fnref:realnvp" class="reversefootnote" role="doc-backlink">↩</a></p>
</li>
<li id="fn:vaerezende" role="doc-endnote">
<p>Rezende, Mohamed and Wierstra, “<a href="https://arxiv.org/abs/1401.4082">Stochastic Backpropagation and Approximate Inference in Deep Generative Models</a>”, International Conference on Machine Learning, 2014. <a href="#fnref:vaerezende" class="reversefootnote" role="doc-backlink">↩</a></p>
</li>
<li id="fn:vaekingma" role="doc-endnote">
<p>Kingma and Welling, “<a href="https://arxiv.org/abs/1312.6114">Auto-Encoding Variational Bayes</a>”, International Conference on Learning Representations, 2014. <a href="#fnref:vaekingma" class="reversefootnote" role="doc-backlink">↩</a></p>
</li>
<li id="fn:pc" role="doc-endnote">
<p>Bowman, Vilnis, Vinyals, Dai, Jozefowicz and Bengio, “<a href="https://arxiv.org/abs/1511.06349">Generating Sentences from a Continuous Space</a>”, 20th SIGNLL Conference on Computational Natural Language Learning, 2016. <a href="#fnref:pc" class="reversefootnote" role="doc-backlink">↩</a></p>
</li>
<li id="fn:gans" role="doc-endnote">
<p>Goodfellow, Pouget-Abadie, Mirza, Xu, Warde-Farley, Ozair, Courville and Bengio, “<a href="http://papers.nips.cc/paper/5423-generative-adversarial-nets">Generative Adversarial Nets</a>”, Advances in neural information processing systems 27 (NeurIPS), 2014. <a href="#fnref:gans" class="reversefootnote" role="doc-backlink">↩</a></p>
</li>
<li id="fn:energy" role="doc-endnote">
<p>Du and Mordatch, “<a href="https://arxiv.org/abs/1903.08689">Implicit Generation and Generalization in Energy-Based Models</a>”, arXiv, 2019. <a href="#fnref:energy" class="reversefootnote" role="doc-backlink">↩</a></p>
</li>
<li id="fn:wgan" role="doc-endnote">
<p>Arjovsky, Chintala and Bottou, “<a href="https://arxiv.org/abs/1701.07875">Wasserstein GAN</a>”, arXiv, 2017. <a href="#fnref:wgan" class="reversefootnote" role="doc-backlink">↩</a></p>
</li>
<li id="fn:swa" role="doc-endnote">
<p>Kolouri, Pope, Martin and Rohde, “<a href="https://arxiv.org/abs/1804.01947">Sliced-Wasserstein Autoencoder: An Embarrassingly Simple Generative Model</a>”, arXiv, 2018. <a href="#fnref:swa" class="reversefootnote" role="doc-backlink">↩</a></p>
</li>
<li id="fn:aiqn" role="doc-endnote">
<p>Ostrovski, Dabney and Munos, “<a href="https://arxiv.org/abs/1806.05575">Autoregressive Quantile Networks for Generative Modeling</a>”, International Conference on Machine Learning, 2018. <a href="#fnref:aiqn" class="reversefootnote" role="doc-backlink">↩</a></p>
</li>
<li id="fn:scorematching" role="doc-endnote">
<p>Hyvärinen, “<a href="http://www.jmlr.org/papers/v6/hyvarinen05a.html">Estimation of Non-Normalized Statistical Models by Score Matching</a>”, Journal of Machine Learning Research, 2005. <a href="#fnref:scorematching" class="reversefootnote" role="doc-backlink">↩</a></p>
</li>
<li id="fn:ssm" role="doc-endnote">
<p>Song, Garg, Shi and Ermon, “<a href="https://arxiv.org/abs/1905.07088">Sliced Score Matching: A Scalable Approach to Density and Score Estimation</a>”, UAI, 2019. <a href="#fnref:ssm" class="reversefootnote" role="doc-backlink">↩</a></p>
</li>
<li id="fn:scorebased" role="doc-endnote">
<p>Song and Ermon, “<a href="http://papers.nips.cc/paper/9361-generative-modeling-by-estimating-gradients-of-the-data-distribution">Generative Modeling by Estimating Gradients of the Data Distribution</a>”, Advances in neural information processing systems 32 (NeurIPS), 2019. <a href="#fnref:scorebased" class="reversefootnote" role="doc-backlink">↩</a></p>
</li>
<li id="fn:wavenet" role="doc-endnote">
<p>Van den Oord, Dieleman, Zen, Simonyan, Vinyals, Graves, Kalchbrenner, Senior and Kavukcuoglu, “<a href="https://arxiv.org/abs/1609.03499">WaveNet: A Generative Model for Raw Audio</a>”, arXiv, 2016. <a href="#fnref:wavenet" class="reversefootnote" role="doc-backlink">↩</a></p>
</li>
<li id="fn:samplernn" role="doc-endnote">
<p>Mehri, Kumar, Gulrajani, Kumar, Jain, Sotelo, Courville and Bengio, “<a href="https://arxiv.org/abs/1612.07837">SampleRNN: An Unconditional End-to-End Neural Audio Generation Model</a>”, International Conference on Learning Representations, 2017. <a href="#fnref:samplernn" class="reversefootnote" role="doc-backlink">↩</a></p>
</li>
<li id="fn:parallelwavenet" role="doc-endnote">
<p>Van den Oord, Li, Babuschkin, Simonyan, Vinyals, Kavukcuoglu, van den Driessche, Lockhart, Cobo, Stimberg, Casagrande, Grewe, Noury, Dieleman, Elsen, Kalchbrenner, Zen, Graves, King, Walters, Belov and Hassabis, “<a href="https://arxiv.org/abs/1711.10433">Parallel WaveNet: Fast High-Fidelity Speech Synthesis</a>”, International Conference on Machine Learning, 2018. <a href="#fnref:parallelwavenet" class="reversefootnote" role="doc-backlink">↩</a></p>
</li>
<li id="fn:clarinet" role="doc-endnote">
<p>Ping, Peng and Chen, “<a href="https://arxiv.org/abs/1807.07281">ClariNet: Parallel Wave Generation in End-to-End Text-to-Speech</a>”, International Conference on Learning Representations, 2019. <a href="#fnref:clarinet" class="reversefootnote" role="doc-backlink">↩</a></p>
</li>
<li id="fn:waveglow" role="doc-endnote">
<p>Prenger, Valle and Catanzaro, “<a href="https://arxiv.org/abs/1811.00002">WaveGlow: A Flow-based Generative Network for Speech Synthesis</a>”, International Conference on Acoustics, Speech, and Signal Processing, 2019. <a href="#fnref:waveglow" class="reversefootnote" role="doc-backlink">↩</a></p>
</li>
<li id="fn:flowavenet" role="doc-endnote">
<p>Kim, Lee, Song, Kim and Yoon, “<a href="https://arxiv.org/abs/1811.02155">FloWaveNet : A Generative Flow for Raw Audio</a>”, International Conference on Machine Learning, 2019. <a href="#fnref:flowavenet" class="reversefootnote" role="doc-backlink">↩</a></p>
</li>
<li id="fn:waveflow" role="doc-endnote">
<p>Ping, Peng, Zhao and Song, “<a href="https://arxiv.org/abs/1912.01219">WaveFlow: A Compact Flow-based Model for Raw Audio</a>”, arXiv, 2019. <a href="#fnref:waveflow" class="reversefootnote" role="doc-backlink">↩</a></p>
</li>
<li id="fn:blow" role="doc-endnote">
<p>Serrà, Pascual and Segura, “<a href="https://papers.nips.cc/paper/8904-blow-a-single-scale-hyperconditioned-flow-for-non-parallel-raw-audio-voice-conversion">Blow: a single-scale hyperconditioned flow for non-parallel raw-audio voice conversion</a>”, Advances in neural information processing systems 32 (NeurIPS), 2019. <a href="#fnref:blow" class="reversefootnote" role="doc-backlink">↩</a></p>
</li>
<li id="fn:vqvae" role="doc-endnote">
<p>Van den Oord, Vinyals and Kavukcuoglu, “<a href="http://papers.nips.cc/paper/7210-neural-discrete-representation-learning">Neural Discrete Representation Learning</a>”, Advances in neural information processing systems 30 (NeurIPS), 2017. <a href="#fnref:vqvae" class="reversefootnote" role="doc-backlink">↩</a></p>
</li>
<li id="fn:challenge" role="doc-endnote">
<p>Dieleman, Van den Oord and Simonyan, “<a href="https://papers.nips.cc/paper/8023-the-challenge-of-realistic-music-generation-modelling-raw-audio-at-scale">The challenge of realistic music generation: modelling raw audio at scale</a>”, Advances in neural information processing systems 31 (NeurIPS), 2018. <a href="#fnref:challenge" class="reversefootnote" role="doc-backlink">↩</a></p>
</li>
<li id="fn:maestro" role="doc-endnote">
<p>Hawthorne, Stasyuk, Roberts, Simon, Huang, Dieleman, Elsen, Engel and Eck, “<a href="https://openreview.net/forum?id=r1lYRjC9F7">Enabling Factorized Piano Music Modeling and Generation with the MAESTRO Dataset</a>”, International Conference on Learning Representations, 2019. <a href="#fnref:maestro" class="reversefootnote" role="doc-backlink">↩</a></p>
</li>
<li id="fn:manzellithakkar" role="doc-endnote">
<p>Manzelli, Thakkar, Siahkamari and Kulis, “<a href="https://arxiv.org/abs/1806.09905">Conditioning Deep Generative Raw Audio Models for Structured Automatic Music</a>”, International Society for Music Information Retrieval Conference, 2018. <a href="#fnref:manzellithakkar" class="reversefootnote" role="doc-backlink">↩</a></p>
</li>
<li id="fn:sparsetransformer" role="doc-endnote">
<p>Child, Gray, Radford and Sutskever, “<a href="https://arxiv.org/abs/1904.10509">Generating Long Sequences with Sparse Transformers</a>”, arXiv, 2019. <a href="#fnref:sparsetransformer" class="reversefootnote" role="doc-backlink">↩</a></p>
</li>
<li id="fn:transformer" role="doc-endnote">
<p>Vaswani, Shazeer, Parmar, Uszkoreit, Jones, Gomez, Kaiser and Polosukhin, “<a href="http://papers.nips.cc/paper/7181-attention-is-all-you-need">Attention is All you Need</a>”, Advances in neural information processing systems 30 (NeurIPS), 2017. <a href="#fnref:transformer" class="reversefootnote" role="doc-backlink">↩</a></p>
</li>
<li id="fn:umtn" role="doc-endnote">
<p>Mor, Wolf, Polyak and Taigman, “<a href="https://openreview.net/forum?id=HJGkisCcKm">A Universal Music Translation Network</a>”, International Conference on Learning Representations, 2019. <a href="#fnref:umtn" class="reversefootnote" role="doc-backlink">↩</a></p>
</li>
<li id="fn:wavegan" role="doc-endnote">
<p>Donahue, McAuley and Puckette, “<a href="https://openreview.net/forum?id=ByMVTsR5KQ">Adversarial Audio Synthesis</a>”, International Conference on Learning Representations, 2019. <a href="#fnref:wavegan" class="reversefootnote" role="doc-backlink">↩</a></p>
</li>
<li id="fn:gansynth" role="doc-endnote">
<p>Engel, Agrawal, Chen, Gulrajani, Donahue and Roberts, “<a href="https://openreview.net/forum?id=H1xQVn09FX">GANSynth: Adversarial Neural Audio Synthesis</a>”, International Conference on Learning Representations, 2019. <a href="#fnref:gansynth" class="reversefootnote" role="doc-backlink">↩</a></p>
</li>
<li id="fn:melgan" role="doc-endnote">
<p>Kumar, Kumar, de Boissiere, Gestin, Teoh, Sotelo, de Brébisson, Bengio and Courville, “<a href="https://papers.nips.cc/paper/9629-melgan-generative-adversarial-networks-for-conditional-waveform-synthesis">MelGAN: Generative Adversarial Networks for Conditional Waveform Synthesis</a>”, Advances in neural information processing systems 32 (NeurIPS), 2019. <a href="#fnref:melgan" class="reversefootnote" role="doc-backlink">↩</a></p>
</li>
<li id="fn:gantts" role="doc-endnote">
<p>Bińkowski, Donahue, Dieleman, Clark, Elsen, Casagrande, Cobo and Simonyan, “<a href="https://openreview.net/forum?id=r1gfQgSFDr">High Fidelity Speech Synthesis with Adversarial Networks</a>”, International Conference on Learning Representations, 2020. <a href="#fnref:gantts" class="reversefootnote" role="doc-backlink">↩</a></p>
</li>
<li id="fn:styletransfer" role="doc-endnote">
<p>Gatys, Ecker and Bethge, “<a href="http://openaccess.thecvf.com/content_cvpr_2016/html/Gatys_Image_Style_Transfer_CVPR_2016_paper.html">Image Style Transfer Using Convolutional Neural Networks</a>”, IEEE Conference on Computer Vision and Pattern Recognition, 2016. <a href="#fnref:styletransfer" class="reversefootnote" role="doc-backlink">↩</a></p>
</li>
<li id="fn:deepimageprior" role="doc-endnote">
<p>Ulyanov, Vedaldi and Lempitsky, “<a href="http://openaccess.thecvf.com/content_cvpr_2018/html/Ulyanov_Deep_Image_Prior_CVPR_2018_paper.html">Deep Image Prior</a>”, IEEE Conference on Computer Vision and Pattern Recognition, 2018. <a href="#fnref:deepimageprior" class="reversefootnote" role="doc-backlink">↩</a></p>
</li>
<li id="fn:randomcnn" role="doc-endnote">
<p>Pons and Serra, “<a href="https://arxiv.org/abs/1805.00237">Randomly weighted CNNs for (music) audio classification</a>”, IEEE International Conference on Acoustics, Speech and Signal Processing, 2019. <a href="#fnref:randomcnn" class="reversefootnote" role="doc-backlink">↩</a></p>
</li>
<li id="fn:griffinlim" role="doc-endnote">
<p>Griffin and Lim, “<a href="https://ieeexplore.ieee.org/abstract/document/1164317/">Signal estimation from modified short-time Fourier transform</a>”, IEEE Transactions on Acoustics, Speech and Signal Processing, 1984. <a href="#fnref:griffinlim" class="reversefootnote" role="doc-backlink">↩</a></p>
</li>
<li id="fn:tacotron" role="doc-endnote">
<p>Wang, Skerry-Ryan, Stanton, Wu, Weiss, Jaitly, Yang, Xiao, Chen, Bengio, Le, Agiomyrgiannakis, Clark and Saurous, “<a href="https://arxiv.org/abs/1703.10135">Tacotron: Towards end-to-end speech synthesis</a>”, Interspeech, 2017. <a href="#fnref:tacotron" class="reversefootnote" role="doc-backlink">↩</a></p>
</li>
<li id="fn:melnet" role="doc-endnote">
<p>Vasquez and Lewis, “<a href="https://arxiv.org/abs/1906.01083">MelNet: A Generative Model for Audio in the Frequency Domain</a>”, arXiv, 2019. <a href="#fnref:melnet" class="reversefootnote" role="doc-backlink">↩</a></p>
</li>
<li id="fn:tacotron2" role="doc-endnote">
<p>Shen, Pang, Weiss, Schuster, Jaitly, Yang, Chen, Zhang, Wang, Skerry-Ryan, Saurous, Agiomyrgiannakis, Wu, “<a href="https://arxiv.org/abs/1712.05884">Natural TTS synthesis by conditioning wavenet on mel spectrogram predictions</a>”, IEEE International Conference on Acoustics, Speech and Signal Processing, 2018. <a href="#fnref:tacotron2" class="reversefootnote" role="doc-backlink">↩</a></p>
</li>
<li id="fn:constantq" role="doc-endnote">
<p>Schörkhuber and Klapuri, “<a href="https://iem.kug.ac.at/fileadmin/media/iem/projects/2010/smc10_schoerkhuber.pdf">Constant-Q transform toolbox for music processing</a>”, Sound and Music Computing Conference, 2010. <a href="#fnref:constantq" class="reversefootnote" role="doc-backlink">↩</a></p>
</li>
<li id="fn:ddsp" role="doc-endnote">
<p>Engel, Hantrakul, Gu and Roberts, “<a href="https://openreview.net/forum?id=B1x1ma4tDr">DDSP: Differentiable Digital Signal Processing</a>”, International Conference on Learning Representations, 2020. <a href="#fnref:ddsp" class="reversefootnote" role="doc-backlink">↩</a></p>
</li>
<li id="fn:musictransformer" role="doc-endnote">
<p>Huang, Vaswani, Uszkoreit, Simon, Hawthorne, Shazeer, Dai, Hoffman, Dinculescu and Eck, “<a href="https://openreview.net/forum?id=rJe4ShAcF7">Music Transformer: Generating Music with Long-Term Structure </a>”, International Conference on Learning Representations, 2019. <a href="#fnref:musictransformer" class="reversefootnote" role="doc-backlink">↩</a></p>
</li>
<li id="fn:spn" role="doc-endnote">
<p>Menick and Kalchbrenner, “<a href="https://openreview.net/forum?id=HylzTiC5Km">Generating High Fidelity Images with Subscale Pixel Networks and Multidimensional Upscaling</a>”, International Conference on Learning Representations, 2019. <a href="#fnref:spn" class="reversefootnote" role="doc-backlink">↩</a></p>
</li>
</ol>
</div>In November last year, I co-presented a tutorial on waveform-based music processing with deep learning with Jordi Pons and Jongpil Lee at ISMIR 2019. Jongpil and Jordi talked about music classification and source separation respectively, and I presented the last part of the tutorial, on music generation in the waveform domain. It was very well received, so I’ve decided to write it up in the form of a blog post.New Lasagne feature: arbitrary expressions as layer parameters2015-11-10T00:00:00+00:002015-11-10T00:00:00+00:00https://benanne.github.io/2015/11/10/arbitrary-expressions-as-params<p style="background-color: #ffa; padding: 1.2em;">
This post is another collaboration with <a href="http://ofai.at/~jan.schlueter">Jan Schlüter from the OFAI</a> (<a href="https://github.com/f0k">@f0k</a> on GitHub), a fellow MIR researcher and one of the lead developers of <a href="http://lasagne.readthedocs.org/">Lasagne</a>. He recently added a cool new feature that we wanted to highlight: enabling the use of arbitrary Theano expressions as layer parameters.
</p>
<p>As many of you probably know, Jan Schlüter and I are part of the team that develops <a href="http://lasagne.readthedocs.org/">Lasagne</a>, a lightweight neural network library built on top of <a href="http://deeplearning.net/software/theano/">Theano</a>.</p>
<p>One of the key <a href="http://lasagne.readthedocs.org/en/latest/user/development.html#philosophy">design principles</a> of Lasagne is <em>transparency</em>: we try not to hide Theano or numpy behind an additional layer of abstractions and encapsulation, but rather expose their functionality and data types and try to follow their conventions. This makes it very easy to learn how to use Lasagne if you already know how to use Theano – there just isn’t all that much extra to learn. But most importantly, it allows you to easily mix and match parts of Lasagne with vanilla Theano code. This is the way Lasagne is meant to be used.</p>
<p>In keeping with this philosophy, Jan recently added a feature that we’ve been discussing early on in designing the API (<a href="https://github.com/Lasagne/Lasagne/issues/11">#11</a>): it allows any learnable layer parameter to be specified as a mathematical expression evaluating to a correctly-shaped tensor. Previously, layer parameters had to be Theano shared variables, i.e., naked tensors to be learned directly. <strong>This new feature makes it possible to constrain network parameters in various, potentially creative ways.</strong> Below, we’ll go through a few examples of what is now possible that wasn’t before.</p>
<h2 id="default-case">Default case</h2>
<p>Let’s create a simple fully-connected layer of 500 units on top of an input layer of 784 units.</p>
<figure class="highlight"><pre><code class="language-python" data-lang="python"><span class="kn">from</span> <span class="nn">lasagne.layers</span> <span class="kn">import</span> <span class="n">InputLayer</span><span class="p">,</span> <span class="n">DenseLayer</span>
<span class="n">batch_size</span> <span class="o">=</span> <span class="mi">64</span>
<span class="n">l1</span> <span class="o">=</span> <span class="n">InputLayer</span><span class="p">((</span><span class="n">batch_size</span><span class="p">,</span> <span class="mi">784</span><span class="p">))</span>
<span class="n">l2</span> <span class="o">=</span> <span class="n">DenseLayer</span><span class="p">(</span><span class="n">l1</span><span class="p">,</span> <span class="n">num_units</span><span class="o">=</span><span class="mi">500</span><span class="p">)</span></code></pre></figure>
<h2 id="autoencoder-with-tied-weights">Autoencoder with tied weights</h2>
<p>Autoencoders with tied weights are a common use case, and until now implementing them in Lasagne was a bit tricky. Weight sharing in Lasagne has always been easy and intuitive:</p>
<figure class="highlight"><pre><code class="language-python" data-lang="python"><span class="n">l2</span> <span class="o">=</span> <span class="n">DenseLayer</span><span class="p">(</span><span class="n">l1</span><span class="p">,</span> <span class="n">num_units</span><span class="o">=</span><span class="mi">500</span><span class="p">)</span>
<span class="n">l3</span> <span class="o">=</span> <span class="n">DenseLayer</span><span class="p">(</span><span class="n">l1</span><span class="p">,</span> <span class="n">num_units</span><span class="o">=</span><span class="mi">500</span><span class="p">,</span> <span class="n">W</span><span class="o">=</span><span class="n">l2</span><span class="p">.</span><span class="n">W</span><span class="p">)</span>
<span class="c1"># l2 and l3 now share the same weight matrix!</span></code></pre></figure>
<p>… but in an autoencoder, you want the weights of the decoding layer to be the <em>transpose</em> of the weights of the encoding layer. So you would do:</p>
<figure class="highlight"><pre><code class="language-python" data-lang="python"><span class="n">l2</span> <span class="o">=</span> <span class="n">DenseLayer</span><span class="p">(</span><span class="n">l1</span><span class="p">,</span> <span class="n">num_units</span><span class="o">=</span><span class="mi">500</span><span class="p">)</span>
<span class="n">l3</span> <span class="o">=</span> <span class="n">DenseLayer</span><span class="p">(</span><span class="n">l2</span><span class="p">,</span> <span class="n">num_units</span><span class="o">=</span><span class="mi">784</span><span class="p">,</span> <span class="n">W</span><span class="o">=</span><span class="n">l2</span><span class="p">.</span><span class="n">W</span><span class="p">.</span><span class="n">T</span><span class="p">)</span></code></pre></figure>
<p>… but that didn’t work before: <code class="language-plaintext highlighter-rouge">l2.W.T</code> is a Theano expression, but not a Theano shared variable as was expected. This is counter-intuitive, and indeed, <a href="https://groups.google.com/forum/#!searchin/lasagne-users/tied$20weights/lasagne-users/ky78GBSgnBI/z10Br4p4kHMJ">people expected it to work</a> and were disappointed to find out that it didn’t. With the new feature, this limitation is gone: the code above now works just fine. Yay!</p>
<h2 id="factorized-weights">Factorized weights</h2>
<p>To reduce the number of parameters in your network (e.g. to prevent overfitting), you could force large parameter matrices to be <em>low-rank</em> by factorizing them. In our example from before, we could factorize the 784x500 weight matrix into the product of a 784x100 and a 100x500 matrix. The number of weights of the layer then goes down from 392000 to 128400 (not including the biases).</p>
<figure class="highlight"><pre><code class="language-python" data-lang="python"><span class="kn">import</span> <span class="nn">theano</span>
<span class="kn">import</span> <span class="nn">theano.tensor</span> <span class="k">as</span> <span class="n">T</span>
<span class="kn">from</span> <span class="nn">lasagne.init</span> <span class="kn">import</span> <span class="n">GlorotUniform</span>
<span class="kn">from</span> <span class="nn">lasagne.utils</span> <span class="kn">import</span> <span class="n">floatX</span>
<span class="n">w_init</span> <span class="o">=</span> <span class="n">GlorotUniform</span><span class="p">()</span>
<span class="n">w1</span> <span class="o">=</span> <span class="n">theano</span><span class="p">.</span><span class="n">shared</span><span class="p">(</span><span class="n">floatX</span><span class="p">(</span><span class="n">w_init</span><span class="p">((</span><span class="mi">784</span><span class="p">,</span> <span class="mi">100</span><span class="p">))))</span>
<span class="n">w2</span> <span class="o">=</span> <span class="n">theano</span><span class="p">.</span><span class="n">shared</span><span class="p">(</span><span class="n">floatX</span><span class="p">(</span><span class="n">w_init</span><span class="p">((</span><span class="mi">100</span><span class="p">,</span> <span class="mi">500</span><span class="p">))))</span>
<span class="n">l2</span> <span class="o">=</span> <span class="n">DenseLayer</span><span class="p">(</span><span class="n">l1</span><span class="p">,</span> <span class="n">num_units</span><span class="o">=</span><span class="mi">500</span><span class="p">,</span> <span class="n">W</span><span class="o">=</span><span class="n">T</span><span class="p">.</span><span class="n">dot</span><span class="p">(</span><span class="n">w1</span><span class="p">,</span> <span class="n">w2</span><span class="p">))</span></code></pre></figure>
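<p>As a quick sanity check on the parameter counts quoted above (plain Python arithmetic, not part of the original code):</p>

```python
# Full weight matrix: 784 inputs x 500 units.
full = 784 * 500                     # 392000 weights

# Low-rank factorization: a 784x100 factor times a 100x500 factor.
factorized = 784 * 100 + 100 * 500   # 128400 weights

print(full, factorized)  # 392000 128400
```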
<p>Granted, this was possible before by inserting a biasless linear layer:</p>
<figure class="highlight"><pre><code class="language-python" data-lang="python"><span class="n">l2_a</span> <span class="o">=</span> <span class="n">DenseLayer</span><span class="p">(</span><span class="n">l1</span><span class="p">,</span> <span class="n">num_units</span><span class="o">=</span><span class="mi">100</span><span class="p">,</span> <span class="n">b</span><span class="o">=</span><span class="bp">None</span><span class="p">,</span> <span class="n">nonlinearity</span><span class="o">=</span><span class="bp">None</span><span class="p">)</span>
<span class="n">l2</span> <span class="o">=</span> <span class="n">DenseLayer</span><span class="p">(</span><span class="n">l2_a</span><span class="p">,</span> <span class="n">num_units</span><span class="o">=</span><span class="mi">500</span><span class="p">)</span></code></pre></figure>
<p>Other types of factorizations <a href="http://arxiv.org/abs/1509.06569">may also be worth investigating!</a></p>
<h2 id="positive-weights">Positive weights</h2>
<p>If you want to force the weights of a layer to be positive, you can learn their logarithm:</p>
<figure class="highlight"><pre><code class="language-python" data-lang="python"><span class="kn">from</span> <span class="nn">lasagne.init</span> <span class="kn">import</span> <span class="n">Normal</span>
<span class="n">w</span> <span class="o">=</span> <span class="n">theano</span><span class="p">.</span><span class="n">shared</span><span class="p">(</span><span class="n">floatX</span><span class="p">(</span><span class="n">Normal</span><span class="p">(</span><span class="mf">0.01</span><span class="p">,</span> <span class="n">mean</span><span class="o">=-</span><span class="mi">10</span><span class="p">)((</span><span class="mi">784</span><span class="p">,</span> <span class="mi">500</span><span class="p">))))</span>
<span class="n">l2</span> <span class="o">=</span> <span class="n">DenseLayer</span><span class="p">(</span><span class="n">l1</span><span class="p">,</span> <span class="n">num_units</span><span class="o">=</span><span class="mi">500</span><span class="p">,</span> <span class="n">W</span><span class="o">=</span><span class="n">T</span><span class="p">.</span><span class="n">exp</span><span class="p">(</span><span class="n">w</span><span class="p">))</span></code></pre></figure>
<p>You could also use <code class="language-plaintext highlighter-rouge">T.softplus(w)</code> instead of <code class="language-plaintext highlighter-rouge">T.exp(w)</code>. You might also be tempted to try sticking a ReLU in there (<code class="language-plaintext highlighter-rouge">T.maximum(w, 0)</code>), but note that applying the linear rectifier to the weight matrix would lead to many of the underlying weights getting stuck at negative values, as the linear rectifier has zero gradient for negative inputs!</p>
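<p>Why the rectifier gets weights stuck can be illustrated without Theano. Below is a NumPy sketch (a toy example added for illustration, not from the original post) comparing the gradients of the two parametrizations with respect to the underlying weights:</p>

```python
import numpy as np

w = np.array([-2.0, -0.5, 0.5, 2.0])  # underlying (unconstrained) weights

# W = exp(w): the derivative dW/dw = exp(w) is strictly positive,
# so every underlying weight keeps receiving gradient signal.
grad_exp = np.exp(w)

# W = maximum(w, 0): the derivative is 1 for w > 0 but 0 for w < 0,
# so negative underlying weights receive no gradient and stay stuck.
grad_relu = (w > 0).astype(float)

print(grad_exp)   # all strictly positive
print(grad_relu)  # [0. 0. 1. 1.]
```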
<h2 id="positive-semi-definite-weights">Positive semi-definite weights</h2>
<p>There are plenty of other creative uses, such as constraining weights to be positive semi-definite (for whatever reason):</p>
<figure class="highlight"><pre><code class="language-python" data-lang="python"><span class="n">l2</span> <span class="o">=</span> <span class="n">DenseLayer</span><span class="p">(</span><span class="n">l1</span><span class="p">,</span> <span class="n">num_units</span><span class="o">=</span><span class="mi">500</span><span class="p">)</span>
<span class="n">w</span> <span class="o">=</span> <span class="n">theano</span><span class="p">.</span><span class="n">shared</span><span class="p">(</span><span class="n">floatX</span><span class="p">(</span><span class="n">w_init</span><span class="p">((</span><span class="mi">500</span><span class="p">,</span> <span class="mi">500</span><span class="p">))))</span>
<span class="n">w_psd</span> <span class="o">=</span> <span class="n">T</span><span class="p">.</span><span class="n">dot</span><span class="p">(</span><span class="n">w</span><span class="p">,</span> <span class="n">w</span><span class="p">.</span><span class="n">T</span><span class="p">)</span>
<span class="n">l3</span> <span class="o">=</span> <span class="n">DenseLayer</span><span class="p">(</span><span class="n">l2</span><span class="p">,</span> <span class="n">num_units</span><span class="o">=</span><span class="mi">500</span><span class="p">,</span> <span class="n">W</span><span class="o">=</span><span class="n">w_psd</span><span class="p">)</span></code></pre></figure>
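<p>The product <code class="language-plaintext highlighter-rouge">T.dot(w, w.T)</code> is positive semi-definite by construction: for any vector v, vᵀ(w wᵀ)v = ‖wᵀv‖² ≥ 0. A standalone NumPy check (illustrative only, using a smaller matrix than the 500×500 case above):</p>

```python
import numpy as np

rng = np.random.default_rng(0)
w = rng.standard_normal((50, 50))
w_psd = w @ w.T

# w @ w.T is symmetric, and all of its eigenvalues are non-negative,
# which is exactly the definition of positive semi-definiteness.
eigvals = np.linalg.eigvalsh(w_psd)
print(eigvals.min() >= -1e-8)  # True (up to numerical round-off)
```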
<h2 id="limitations">Limitations</h2>
<p>There are only a couple of limitations to using Theano expressions as layer parameters. One is that Lasagne functions and methods such as <code class="language-plaintext highlighter-rouge">Layer.get_params()</code> will implicitly assume that any shared variable featuring in these Theano expressions is to be treated as a parameter. In practice that means you can’t mix learnable and non-learnable parameter variables in a single expression. Also, the same tags will apply to all shared variables in an expression. More information about parameter tags can be found in <a href="http://lasagne.readthedocs.org/en/latest/modules/layers/base.html#lasagne.layers.Layer.get_params">the documentation</a>.</p>
<p>For almost all use cases, these limitations should not be an issue. If they are, your best bet is to implement a custom layer class. Luckily, <a href="http://lasagne.readthedocs.org/en/latest/user/custom_layers.html">this is also very easy in Lasagne</a>.</p>
<h2 id="why-it-works">Why it works</h2>
<p>All of this is made possible because Lasagne builds on Theano, which takes care of backpropagating through the parameter expression to any underlying learned tensors. In frameworks building on hard-coded layer implementations rather than an automatic expression compiler, all these examples would require writing custom backpropagation code.</p>
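<p>To make this concrete: if the effective weights are <code class="language-plaintext highlighter-rouge">W = T.exp(w)</code>, the chain rule gives dL/dw = dL/dW · exp(w), and Theano derives exactly this kind of expression automatically when you take gradients. The NumPy sketch below (an added illustration with a made-up toy loss) computes that gradient by hand and verifies it against finite differences:</p>

```python
import numpy as np

w = np.array([0.3, -1.2, 0.7])  # underlying log-weights
x = np.array([1.0, 2.0, -0.5])  # some fixed input

def loss(w):
    # toy loss built on the constrained weights W = exp(w)
    return np.sum(np.exp(w) * x) ** 2

# chain rule by hand: dL/dw = dL/dW * dW/dw = 2*s*x * exp(w)
s = np.sum(np.exp(w) * x)
grad_analytic = 2 * s * x * np.exp(w)

# central finite differences as an independent check
eps = 1e-6
grad_fd = np.array([
    (loss(w + eps * np.eye(3)[i]) - loss(w - eps * np.eye(3)[i])) / (2 * eps)
    for i in range(3)
])
assert np.allclose(grad_analytic, grad_fd, atol=1e-4)
```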
<p>If you want to play around with this yourself, try the bleeding-edge version of Lasagne. You can find <a href="http://lasagne.readthedocs.org/en/latest/user/installation.html#bleeding-edge-version">installation instructions here</a>.</p>
<p><strong>Have fun experimenting!</strong> If you’ve done something cool that you’d like to share, feel free to send us a pull request on our <a href="https://github.com/Lasagne/Recipes">Recipes repository</a>.</p>This post is another collaboration with Jan Schlüter from the OFAI (@f0k on GitHub), a fellow MIR researcher and one of the lead developers of Lasagne. He recently added a cool new feature that we wanted to highlight: enabling the use of arbitrary Theano expressions as layer parameters.Paper about my Galaxy Challenge solution2015-03-25T00:00:00+00:002015-03-25T00:00:00+00:00https://benanne.github.io/2015/03/25/gz-paper<p><strong>UPDATE</strong> (April 27th): the paper is now available on the journal website: <a href="http://mnras.oxfordjournals.org/content/450/2/1441">http://mnras.oxfordjournals.org/content/450/2/1441</a></p>
<p>Together with Kyle Willett, one of the organizers of the <a href="http://www.kaggle.com/c/galaxy-zoo-the-galaxy-challenge">Galaxy Challenge</a>, I’ve written a paper about my winning solution for this competition. It is <a href="http://arxiv.org/abs/1503.07077">available on ArXiv</a>.</p>
<p>The paper has been accepted for publication in <a href="http://mnras.oxfordjournals.org/">MNRAS</a>, a journal on astronomy and astrophysics, but is also aimed at people with a machine learning background. Due to this dual audience, it contains both an in-depth overview of deep learning and convolutional networks, and a thorough analysis of the resulting model and its potential impact for astronomy research.</p>
<p>There is some overlap with <a href="http://benanne.github.io/2014/04/05/galaxy-zoo.html">the blog post</a> I wrote after the competition ended, but there is a lot more detail and background information, and the ‘results’ and ‘analysis’ sections are entirely new (although those of you who have seen one of my talks on the subject may have seen some of the images before).</p>
<p>I am very grateful to Kyle Willett for helping me write the manuscript. Without his help, writing a paper for an audience of astronomers would have been an impossible task for me. I believe it’s crucially important that applications of deep learning and machine learning in general get communicated to the people that could benefit from them, in such a way that they might actually consider using them.</p>
<p>I am also grateful to current and former supervisors, Joni Dambre and Benjamin Schrauwen, for supporting me when I was working on this competition and this paper, even though it is only tangentially related to the subject of my PhD.</p>
<p>Original arxiv link: <a href="http://arxiv.org/abs/1503.07077">http://arxiv.org/abs/1503.07077</a></p>UPDATE (April 27th): the paper is now available on the journal website: http://mnras.oxfordjournals.org/content/450/2/1441Classifying plankton with deep neural networks2015-03-17T00:00:00+00:002015-03-17T00:00:00+00:00https://benanne.github.io/2015/03/17/plankton<p>The <a href="https://www.kaggle.com/c/datasciencebowl">National Data Science Bowl</a>, a data science competition where the goal was to classify images of plankton, has just ended. I participated with six other members of my research lab, the <a href="http://reslab.elis.ugent.be/">Reservoir lab</a> of prof. Joni Dambre at Ghent University in Belgium. Our team finished 1st! In this post, we’ll explain our approach.</p>
<div style="float: right; width: 50%;"><a href="http://www.datasciencebowl.com/"><img src="/images/ndsb.png" alt="National Data Science Bowl" /></a></div>
<p>The <strong>≋ Deep Sea ≋</strong> team consisted of <a href="http://reslab.elis.ugent.be/aaron">Aäron van den Oord</a>, <a href="http://irakorshunova.github.io/">Ira Korshunova</a>, Jeroen Burms, <a href="http://317070.github.io/">Jonas Degrave</a>,
<a href="http://lpigou.github.io/">Lionel
Pigou</a>, <a href="https://twitter.com/pieterbuteneers">Pieter Buteneers</a> and myself. We are all master students, PhD students and post-docs at Ghent University. We decided to participate together because we are all very interested in deep learning, and a collaborative effort to solve a practical problem is a great way to learn.</p>
<p>There were seven of us, so over the course of three months, we were able to try a plethora of different things, including a bunch of recently published techniques, and a couple of novelties. This blog post was written jointly by the team and will cover all the different ingredients that went into our solution in some detail.</p>
<h2 id="-overview"><a name="overview"></a> Overview</h2>
<p>This blog post is going to be pretty long! Here’s an overview of the different sections. If you want to skip ahead, just click the section title to go there.</p>
<ul>
<li><em><a href="#introduction">Introduction</a></em></li>
<li><em><a href="#prepro-augmentation">Pre-processing and data augmentation</a></em></li>
<li><em><a href="#architecture">Network architecture</a></em></li>
<li><em><a href="#training">Training</a></em></li>
<li><em><a href="#unsupervised">Unsupervised and semi-supervised approaches</a></em></li>
<li><em><a href="#averaging">Model averaging</a></em></li>
<li><em><a href="#miscellany">Miscellany</a></em></li>
<li><em><a href="#conclusion">Conclusion</a></em></li>
</ul>
<h2 id="-introduction"><a name="introduction"></a> Introduction</h2>
<h3 id="the-problem">The problem</h3>
<p>The goal of the competition was to classify grayscale images of plankton into one of 121 classes. They were created using an underwater camera that is towed through an area. The resulting images are then used by scientists to determine which species occur in this area, and how common they are. There are typically a lot of these images, and they need to be annotated before any conclusions can be drawn. Automating this process as much as possible should save a lot of time!</p>
<p>The images obtained using the camera were already processed by a segmentation algorithm to identify and isolate individual organisms, and then cropped accordingly. Interestingly, the size of an organism in the resulting images is proportional to its actual size, and does not depend on the distance to the lens of the camera. This means that size carries useful information for the task of identifying the species. In practice it also means that all the images in the dataset have different sizes.</p>
<p>Participants were expected to build a model that produces a probability distribution across the 121 classes for each image. These predicted distributions were scored using the log loss (which corresponds to the negative log likelihood or equivalently the cross-entropy loss).</p>
<p>This loss function has some interesting properties: for one, it is extremely sensitive to overconfident predictions. If your model predicts a probability of 1 for a certain class, and it happens to be wrong, the loss becomes infinite. It is also differentiable, which means that models trained with gradient-based methods (such as neural networks) can optimize it directly - it is unnecessary to use a surrogate loss function.</p>
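<p>To put some numbers on this sensitivity, here is the per-example log loss as a function of the probability assigned to the correct class (a small illustration; Kaggle clips predicted probabilities to keep the loss finite, and the competition score averages this quantity over all images):</p>

```python
import math

def log_loss_single(p_true, eps=1e-15):
    # per-example log loss, given the predicted probability of the true
    # class; clipping keeps the loss finite when p_true == 0
    p_true = min(max(p_true, eps), 1.0 - eps)
    return -math.log(p_true)

print(round(log_loss_single(0.9), 3))   # 0.105
print(round(log_loss_single(0.5), 3))   # 0.693
print(round(log_loss_single(0.01), 3))  # 4.605
print(round(log_loss_single(1e-9), 3))  # 20.723: confidently wrong is ruinous
```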
<p>Interestingly, optimizing the log loss is not quite the same as optimizing classification accuracy. Although the two are obviously correlated, we paid special attention to this because it was often the case that significant improvements to the log loss would barely affect the classification accuracy of the models.</p>
<h3 id="the-solution-convnets">The solution: convnets!</h3>
<p>Image classification problems are often approached using convolutional neural networks these days, and with good reason: they achieve record-breaking performance on some really difficult tasks.</p>
<p>A challenge with this competition was the size of the dataset: about 30000 examples for 121 classes. Several classes had fewer than 20 examples in total. Deep learning approaches are often said to require enormous amounts of data to work well, but recently this notion has been challenged, and our results in this competition also indicate that this is not necessarily true. Judicious use of techniques to prevent overfitting such as dropout, weight decay, data augmentation, pre-training, pseudo-labeling and parameter sharing, has enabled us to train very large models with up to 27 million parameters on this dataset.</p>
<p>Some of you may remember that I <a href="http://benanne.github.io/2014/04/05/galaxy-zoo.html">participated in another Kaggle competition last year: the Galaxy Challenge</a>. The goal of that competition was to classify images of galaxies. It turns out that a lot of the things I learned during that competition were also applicable here. Most importantly, just like images of galaxies, images of plankton are (mostly) rotation invariant. I used this property for data augmentation, and incorporated it into the model architecture.</p>
<h3 id="software-and-hardware">Software and hardware</h3>
<p>We used <a href="https://www.python.org/">Python</a>, <a href="http://www.numpy.org/">NumPy</a> and <a href="http://deeplearning.net/software/theano/">Theano</a> to implement our solution, in combination with the <a href="https://developer.nvidia.com/cuDNN">cuDNN</a> library. We also used <a href="http://mathema.tician.de/software/pycuda/">PyCUDA</a> to implement a few custom kernels.</p>
<p>Our code is mostly based on the <a href="https://github.com/benanne/Lasagne">Lasagne</a> library, which provides a bunch of layer classes and some utilities that make it easier to build neural nets in Theano. This is currently being developed by a group of researchers with different affiliations, including Aäron and myself. We hope to release the first version soon!</p>
<p>We also used <a href="http://scikit-image.org/">scikit-image</a> for pre-processing and augmentation, and <a href="https://github.com/fmder/ghalton">ghalton</a> for quasi-random number generation. During the competition, we kept track of all of our results in a Google Drive spreadsheet. Our code was hosted on a private GitHub repository, with everyone in charge of their own branch.</p>
<p>We trained our models on the NVIDIA GPUs that we have in the lab, which include GTX 980, GTX 680 and Tesla K40 cards.</p>
<h2 id="-pre-processing-and-data-augmentation"><a name="prepro-augmentation"><a> Pre-processing and data augmentation</a></a></h2>
<p>We performed very little pre-processing, other than rescaling the images in various ways and then performing global zero mean unit variance (ZMUV) normalization, to improve the stability of training and increase the convergence speed.</p>
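<p>Global ZMUV normalization amounts to estimating a single mean and standard deviation over all pixels of the training set and applying them everywhere. A sketch (illustrative; the function names are made up):</p>

```python
import numpy as np

def fit_zmuv(train_images):
    """Estimate one global mean and standard deviation over every pixel
    in the (rescaled) training set."""
    return train_images.mean(), train_images.std()

def apply_zmuv(images, mean, std):
    """Normalize to zero mean and unit variance using training statistics."""
    return (images - mean) / std
```
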
<p>Rescaling the images was necessary because they vary in size a lot: the smallest ones are less than 40 by 40 pixels, whereas the largest ones are up to 400 by 400 pixels. We experimented with various (combinations of) rescaling strategies. For most networks, we simply rescaled the largest side of each image to a fixed length.</p>
<p>We also tried estimating the size of the creatures using <a href="http://en.wikipedia.org/wiki/Image_moment">image moments</a>. Unfortunately, centering and rescaling the images based on image moments did not improve results, but they turned out to be useful as additional features for classification (see below).</p>
<h3 id="data-augmentation">Data augmentation</h3>
<p>We augmented the data to artificially increase the size of the dataset. We used various affine transforms, and gradually increased the intensity of the augmentation as our models started to overfit more. We ended up with some pretty extreme augmentation parameters:</p>
<ul>
<li><strong>rotation</strong>: random with angle between 0° and 360° (uniform)</li>
<li><strong>translation</strong>: random with shift between -10 and 10 pixels (uniform)</li>
<li><strong>rescaling</strong>: random with scale factor between 1/1.6 and 1.6 (log-uniform)</li>
<li><strong>flipping</strong>: yes or no (bernoulli)</li>
<li><strong>shearing</strong>: random with angle between -20° and 20° (uniform)</li>
<li><strong>stretching</strong>: random with stretch factor between 1/1.3 and 1.3 (log-uniform)</li>
</ul>
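<p>The parameter ranges above can be sampled and composed into a single homogeneous affine matrix. The sketch below is illustrative: the exact order in which we composed the transformations is not spelled out here, so the composition order is an assumption.</p>

```python
import numpy as np

def random_augmentation_matrix(rng):
    """Sample a 3x3 homogeneous 2D affine matrix using the augmentation
    parameter ranges listed above (composition order is a guess)."""
    rotation = rng.uniform(0, 2 * np.pi)                         # 0-360 degrees
    shift = rng.uniform(-10, 10, size=2)                         # pixels
    zoom = np.exp(rng.uniform(np.log(1 / 1.6), np.log(1.6)))     # log-uniform
    flip = rng.random() < 0.5                                    # bernoulli
    shear = np.deg2rad(rng.uniform(-20, 20))
    stretch = np.exp(rng.uniform(np.log(1 / 1.3), np.log(1.3)))  # log-uniform

    sx = zoom * stretch * (-1.0 if flip else 1.0)
    sy = zoom / stretch
    c, s = np.cos(rotation), np.sin(rotation)
    # rotation (with shear applied to the second axis), anisotropic scaling,
    # and translation, in homogeneous coordinates
    return np.array([[c * sx, -np.sin(rotation + shear) * sy, shift[0]],
                     [s * sx,  np.cos(rotation + shear) * sy, shift[1]],
                     [0.0,     0.0,                           1.0]])
```
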
<p>We augmented the data on-demand during training (<em>realtime augmentation</em>), which allowed us to combine the image rescaling and augmentation into a single affine transform. The augmentation was all done on the CPU while the GPU was training on the previous chunk of data.</p>
<figure>
<a href="/images/augmentation_noaug_cropped.png"><img style="width: 48%;" src="/images/augmentation_noaug_cropped.png" alt="" /></a>
<a href="/images/augmentation_aug_cropped.png"><img style="width: 48%;" src="/images/augmentation_aug_cropped.png" alt="" /></a>
<figcaption>Pre-processed images (left) and augmented versions of the same images (right).</figcaption>
</figure>
<p>We experimented with elastic distortions at some point, but this did not improve performance, although it reduced overfitting slightly. We also tried sampling the augmentation transform parameters from Gaussian instead of uniform distributions, but this did not improve results either.</p>
<h2 id="-network-architecture"><a name="architecture"><a> Network architecture</a></a></h2>
<p>Most of our convnet architectures were strongly inspired by <a href="http://arxiv.org/abs/1409.1556">OxfordNet</a>: they consist of lots of convolutional layers with 3x3 filters. We used ‘same’ convolutions (i.e. the output feature maps are the same size as the input feature maps) and overlapping pooling with window size 3 and stride 2.</p>
<p>We started with fairly shallow models by modern standards (~6 layers) and gradually added more layers when we noticed it improved performance (it usually did). Near the end of the competition, we were training models with up to 16 layers. The challenge, as always, was balancing improved performance against increased overfitting.</p>
<p>We experimented with strided convolutions with 7x7 filters in the first two layers for a while, inspired by the work of <a href="http://arxiv.org/abs/1502.01852">He et al.</a>, but we were unable to achieve the same performance with this in our networks.</p>
<h3 id="cyclic-pooling">Cyclic pooling</h3>
<p>When I participated in the <a href="http://benanne.github.io/2014/04/05/galaxy-zoo.html">Galaxy Challenge</a>, one of the things I did differently from other competitors was to exploit the rotational symmetry of the images to share parameters in the network. I applied the same stack of convolutional layers to several rotated and flipped versions of the same input image, concatenated the resulting feature representations, and fed those into a stack of dense layers. This allowed the network to use the same feature extraction pipeline to “look at” the input from different angles.</p>
<p>Here, we took this a step further. Rather than concatenating the feature representations, we decided to pool across them to get rotation invariance. Here’s how it worked in practice: the images in a minibatch occur 4 times, in 4 different orientations. They are processed by the network in parallel, and at the top, the feature maps are pooled together. We decided to call this <strong>cyclic pooling</strong>, after <a href="http://en.wikipedia.org/wiki/Cyclic_group">cyclic groups</a>.</p>
<figure>
<a href="/images/cyclicpool.png"><img src="/images/cyclicpool.png" alt="" /></a>
<figcaption>Schematic representation of a convnet with cyclic pooling.</figcaption>
</figure>
<p>The nice thing about 4-way cyclic pooling is that it can be implemented very efficiently: the images are rotated by 0, 90, 180 and 270 degrees. All of these rotations can be achieved simply by transposing and flipping image axes. That means no interpolation is required.</p>
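<p>In NumPy terms, the 'cyclic slice' that stacks the four lossless rotations along the batch axis can be sketched as follows (<code>np.rot90</code> on the spatial axes only transposes and flips, so no interpolation happens):</p>

```python
import numpy as np

def cyclic_slice(batch):
    """Stack 0/90/180/270-degree rotations of a (b, c, h, w) batch along
    the batch axis, quadrupling the effective batch size."""
    return np.concatenate(
        [np.rot90(batch, k, axes=(2, 3)) for k in range(4)], axis=0)
```
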
<p>Cyclic pooling also allowed us to reduce the batch size by a factor of 4: instead of having batches of 128 images, each batch now contained 32 images and was then turned into a batch with an effective size of 128 again inside the network, by stacking the original batch in 4 orientations. After the pooling step, the batch size was reduced to 32 again.</p>
<p>We tried several pooling functions over the course of the competition, as well as different positions in the network for the pooling operation (just before the output layer, between hidden layers, …). It turned out that <strong>root-mean-square pooling</strong> gave much better results than mean pooling or max pooling. We weren’t able to find a good explanation for this, but we suspect it may have something to do with rotational phase invariance.</p>
<p>One of our models pooled over 8 rotations, spaced apart 45 degrees. This required generating the input images at two angles (0 and 45 degrees). We also considered having the model do 8-way pooling by including flipped versions of each rotated image copy (<strong>dihedral pooling</strong>, after <a href="http://en.wikipedia.org/wiki/Dihedral_group">dihedral groups</a>). Unfortunately this did not work better.</p>
<h3 id="rolling-feature-maps">‘Rolling’ feature maps</h3>
<p>Cyclic pooling modestly improved our results, but it can be taken a step further. A cyclic pooling convnet extracts features from input images in four different orientations. An alternative interpretation is that its filters are applied to the input images in four different orientations. That means we can combine the stacks of feature maps from the different orientations into one big stack, and then learn the next layer of features on this combined input. As a result, the network then appears to have 4 times more filters than it actually has!</p>
<p>This is cheap to do, since the feature maps are already being computed anyway. We just have to combine them together in the right order and orientation. We named the operation that combines feature maps from different orientations a <strong>roll</strong>.</p>
<figure>
<a href="/images/cyclicroll.png"><img src="/images/cyclicroll.png" alt="" /></a>
<figcaption>Schematic representation of a roll operation inside a convnet with cyclic pooling.</figcaption>
</figure>
<p>Roll operations can be inserted after dense layers or after convolutional layers. In the latter case, care has to be taken to rotate the feature maps appropriately, so that they are all aligned.</p>
<p>We originally implemented the operations with a few lines of Theano code. This is a nice demonstration of Theano’s effectiveness for rapid prototyping of new ideas. Later on we spent some time implementing CUDA kernels for the roll operations and their gradients, because networks with many rolled layers were getting pretty slow to train. Using your own CUDA kernels with Theano turns out to be relatively easy in combination with PyCUDA. No additional C-code is required.</p>
<p>In most of the models we evaluated, we only inserted convolutional roll operations after the pooling layers, because this reduced the size of the feature maps that needed to be copied and stacked together.</p>
<p>Note that it is perfectly possible to build a cyclic pooling convnet without any roll operations, but it’s not possible to have roll operations in a network without cyclic pooling. The roll operation is only made possible because the cyclic pooling requires that each input image is processed in four different orientations to begin with.</p>
<h3 id="nonlinearities">Nonlinearities</h3>
<p>We experimented with various variants of rectified linear units (ReLUs), as well as maxout units (only in the dense layers). We also tried out smooth non-linearities and the ‘parameterized ReLUs’ that were recently introduced by <a href="http://arxiv.org/abs/1502.01852">He et al.</a>, but found networks with these units to be very prone to overfitting.</p>
<p>However, we had great success with (very) <strong>leaky ReLUs</strong>. Instead of taking the maximum of the input and zero, <code class="language-plaintext highlighter-rouge">y = max(x, 0)</code>, leaky ReLUs take the maximum of the input and a scaled version of the input, <code class="language-plaintext highlighter-rouge">y = max(x, a*x)</code>. Here, <code class="language-plaintext highlighter-rouge">a</code> is a tunable scale parameter. Setting it to zero yields regular ReLUs, and making it trainable yields parameterized ReLUs.</p>
<p>For fairly deep networks (10+ layers), we found that varying this parameter between 0 and 1/2 did not really affect the predictive performance. However, larger values in this range significantly reduced the level of overfitting. This in turn allowed us to scale up our models further. We eventually settled on a = 1/3.</p>
<h3 id="spatial-pooling">Spatial pooling</h3>
<p>We started out using networks with 2 or 3 spatial pooling layers, and we initially had some trouble getting networks with more pooling stages to work well. Most of our final models have 4 pooling stages though.</p>
<p>We started out with the traditional approach of 2x2 max-pooling, but eventually switched to 3x3 max-pooling with stride 2 (which we’ll refer to as 3x3s2), mainly because it allowed us to use a larger input size while keeping the same feature map size at the topmost convolutional layer, and without increasing the computational cost significantly.</p>
<p>As an example, a network with 80x80 input and 4 2x2 pooling stages will have feature maps of size 5x5 at the topmost convolutional layer. If we use 3x3s2 pooling instead, we can feed 95x95 input and get feature maps with the same 5x5 shape. This improved performance and only slowed down training slightly.</p>
<h3 id="multiscale-architectures">Multiscale architectures</h3>
<p>As mentioned before, the images vary widely in size, so we usually rescaled them using the largest dimension of the image as a size estimate. This is clearly suboptimal, because some species of plankton are larger than others. Size carries valuable information.</p>
<p>To allow the network to learn this, we experimented with combinations of different rescaling strategies within the same network, by combining multiple networks with different rescaled inputs together into ‘multiscale’ networks.</p>
<p>What worked best was to combine a network with inputs rescaled based on image size, and a smaller network with inputs rescaled by a fixed factor. Of course this slowed down training quite a bit, but it allowed us to squeeze out a bit more performance.</p>
<h3 id="additional-image-features">Additional image features</h3>
<p>We experimented with training small neural nets on extracted image features to ‘correct’ the predictions of our convnets. We referred to this as ‘late fusing’ because the feature network and the convnet were joined only at the output layer (before the softmax). We also tried joining them at earlier layers, but consistently found this to work worse, because of overfitting.</p>
<p>We thought this could be useful, because the features can be extracted from the raw (i.e. non-rescaled) images, so this procedure could provide additional information that is missed by the convnets. Here are some examples of types of features we evaluated (the ones we ended up using are in bold):</p>
<ul>
<li><strong>Image size in pixels</strong></li>
<li><strong>Size and shape estimates based on image moments</strong></li>
<li>Hu moments</li>
<li>Zernike moments</li>
<li>Parameter Free Threshold Adjacency Statistics</li>
<li>Local Binary Patterns</li>
<li><strong>Haralick texture features</strong></li>
<li>Features from the competition tutorial</li>
<li>Combinations of the above</li>
</ul>
<p>The image size, the features based on image moments and the Haralick texture features were the ones that stood out the most in terms of performance. The features were fed to a neural net with two dense layers of 80 units. The final layer of the model was fused with previously generated predictions of our best convnet-based models. Using this approach, we didn’t have to retrain the convnets nor did we have to regenerate predictions (which saved us a lot of time).</p>
<p>To deal with variance due to the random weight initialization, we trained each feature network 10 times and blended the copies with uniform weights. This resulted in a consistent validation loss decrease of 0.01 (or 1.81%) on average, which was quite significant near the end of the competition.</p>
<p>Interestingly, late fusion with image size and features based on image moments seems to help just as much for multiscale models as for regular convnets. This is a bit counterintuitive: we expected both approaches to help because they could extract information about the size of the creatures, so the obtained performance improvements would overlap. The fact they were fully orthogonal was a nice surprise.</p>
<h3 id="example-convnet-architecture">Example convnet architecture</h3>
<p>Here’s an example of an architecture that works well. It has 13 layers with parameters (10 convolutional, 3 fully connected) and 4 spatial pooling layers. The input shape is <code class="language-plaintext highlighter-rouge">(32, 1, 95, 95)</code>, in <code class="language-plaintext highlighter-rouge">bc01</code> order (batch size, number of channels, height, width). The output shape is <code class="language-plaintext highlighter-rouge">(32, 121)</code>. For a given input, the network outputs 121 probabilities that sum to 1, one for each class.</p>
<table>
<thead>
<tr>
<th style="text-align: left">Layer type</th>
<th style="text-align: left">Size</th>
<th style="text-align: left">Output shape</th>
</tr>
</thead>
<tbody>
<tr>
<td style="text-align: left">cyclic slice</td>
<td style="text-align: left"> </td>
<td style="text-align: left">(128, 1, 95, 95)</td>
</tr>
<tr>
<td style="text-align: left">convolution</td>
<td style="text-align: left">32 3x3 filters</td>
<td style="text-align: left">(128, 32, 95, 95)</td>
</tr>
<tr>
<td style="text-align: left">convolution</td>
<td style="text-align: left">16 3x3 filters</td>
<td style="text-align: left">(128, 16, 95, 95)</td>
</tr>
<tr>
<td style="text-align: left">max pooling</td>
<td style="text-align: left">3x3, stride 2</td>
<td style="text-align: left">(128, 16, 47, 47)</td>
</tr>
<tr>
<td style="text-align: left">cyclic roll</td>
<td style="text-align: left"> </td>
<td style="text-align: left">(128, 64, 47, 47)</td>
</tr>
<tr>
<td style="text-align: left">convolution</td>
<td style="text-align: left">64 3x3 filters</td>
<td style="text-align: left">(128, 64, 47, 47)</td>
</tr>
<tr>
<td style="text-align: left">convolution</td>
<td style="text-align: left">32 3x3 filters</td>
<td style="text-align: left">(128, 32, 47, 47)</td>
</tr>
<tr>
<td style="text-align: left">max pooling</td>
<td style="text-align: left">3x3, stride 2</td>
<td style="text-align: left">(128, 32, 23, 23)</td>
</tr>
<tr>
<td style="text-align: left">cyclic roll</td>
<td style="text-align: left"> </td>
<td style="text-align: left">(128, 128, 23, 23)</td>
</tr>
<tr>
<td style="text-align: left">convolution</td>
<td style="text-align: left">128 3x3 filters</td>
<td style="text-align: left">(128, 128, 23, 23)</td>
</tr>
<tr>
<td style="text-align: left">convolution</td>
<td style="text-align: left">128 3x3 filters</td>
<td style="text-align: left">(128, 128, 23, 23)</td>
</tr>
<tr>
<td style="text-align: left">convolution</td>
<td style="text-align: left">64 3x3 filters</td>
<td style="text-align: left">(128, 64, 23, 23)</td>
</tr>
<tr>
<td style="text-align: left">max pooling</td>
<td style="text-align: left">3x3, stride 2</td>
<td style="text-align: left">(128, 64, 11, 11)</td>
</tr>
<tr>
<td style="text-align: left">cyclic roll</td>
<td style="text-align: left"> </td>
<td style="text-align: left">(128, 256, 11, 11)</td>
</tr>
<tr>
<td style="text-align: left">convolution</td>
<td style="text-align: left">256 3x3 filters</td>
<td style="text-align: left">(128, 256, 11, 11)</td>
</tr>
<tr>
<td style="text-align: left">convolution</td>
<td style="text-align: left">256 3x3 filters</td>
<td style="text-align: left">(128, 256, 11, 11)</td>
</tr>
<tr>
<td style="text-align: left">convolution</td>
<td style="text-align: left">128 3x3 filters</td>
<td style="text-align: left">(128, 128, 11, 11)</td>
</tr>
<tr>
<td style="text-align: left">max pooling</td>
<td style="text-align: left">3x3, stride 2</td>
<td style="text-align: left">(128, 128, 5, 5)</td>
</tr>
<tr>
<td style="text-align: left">cyclic roll</td>
<td style="text-align: left"> </td>
<td style="text-align: left">(128, 512, 5, 5)</td>
</tr>
<tr>
<td style="text-align: left">fully connected</td>
<td style="text-align: left">512 2-piece maxout units</td>
<td style="text-align: left">(128, 512)</td>
</tr>
<tr>
<td style="text-align: left">cyclic pooling (rms)</td>
<td style="text-align: left"> </td>
<td style="text-align: left">(32, 512)</td>
</tr>
<tr>
<td style="text-align: left">fully connected</td>
<td style="text-align: left">512 2-piece maxout units</td>
<td style="text-align: left">(32, 512)</td>
</tr>
<tr>
<td style="text-align: left">fully connected</td>
<td style="text-align: left">121-way softmax</td>
<td style="text-align: left">(32, 121)</td>
</tr>
</tbody>
</table>
<p>Note how the ‘cyclic slice’ layer increases the batch size fourfold. The ‘cyclic pooling’ layer reduces it back to 32 again near the end. The ‘cyclic roll’ layers increase the number of feature maps fourfold.</p>
<h2 id="-training"><a name="training"><a> Training</a></a></h2>
<h3 id="validation">Validation</h3>
<p>We split off 10% of the labeled data as a validation set using stratified sampling. Due to the small size of this set, our validation estimates were relatively noisy and we periodically validated some models on the leaderboard as well.</p>
<h3 id="training-algorithm">Training algorithm</h3>
<p>We trained all of our models with stochastic gradient descent (SGD) with Nesterov momentum. We set the momentum parameter to 0.9 and did not tune it further. Most models took between 24 and 48 hours to train to convergence.</p>
<p>We trained most of the models with about 215000 gradient steps and eventually settled on a discrete learning rate schedule with two 10-fold decreases (following <a href="http://www.cs.toronto.edu/~fritz/absps/imagenet.pdf">Krizhevsky et al.</a>), after about 180000 and 205000 gradient steps respectively. For most models we used an initial learning rate of 0.003.</p>
<p>We briefly experimented with the Adam update rule proposed by <a href="http://arxiv.org/abs/1412.6980">Kingma and Ba</a>, as an alternative to Nesterov momentum. We used the version of the algorithm described in the first version of the paper, without the lambda parameter. Although this seemed to speed up convergence by a factor of almost 2x, the results were always slightly worse than those achieved with Nesterov momentum, so we eventually abandoned this idea.</p>
<h3 id="initialization">Initialization</h3>
<p>We used a variant of the orthogonal initialization strategy proposed by <a href="http://arxiv.org/abs/1312.6120">Saxe et al.</a> everywhere. This allowed us to add as many layers as we wanted without running into any convergence problems.</p>
<h3 id="regularization">Regularization</h3>
<p>For most models, we used dropout in the fully connected layers of the network, with a dropout probability of 0.5. We experimented with dropout in the convolutional layers as well for some models.</p>
<p>We also tried Gaussian dropout (using multiplicative Gaussian noise instead of multiplicative Bernoulli noise) and found this to work about as well as traditional dropout.</p>
<p>We discovered near the end of the competition that it was useful to have a small amount of weight decay to stabilize training of larger models (so not just for its regularizing effect). Models with large fully connected layers and without weight decay would often diverge unless the learning rate was decreased considerably, which slowed things down too much.</p>
<h2 id="-unsupervised-and-semi-supervised-approaches"><a name="unsupervised"><a> Unsupervised and semi-supervised approaches</a></a></h2>
<h3 id="unsupervised-pre-training">Unsupervised pre-training</h3>
<p>Since the test set was much larger than the training set, we experimented with using unsupervised pre-training on the test set to initialize the networks. We only pre-trained the convolutional layers, using convolutional auto-encoders (CAE, <a href="http://link.springer.com/chapter/10.1007/978-3-642-21735-7_7">Masci et al.</a>). This approach consists of building a stack of layers implementing the reverse operations (i.e. deconvolution and unpooling) of the layers that are to be pre-trained. These can then be used to try to reconstruct the input of those layers.</p>
<p>In line with the literature, we found that pre-training a network serves as an excellent regularizer (much higher train error, slightly better validation score), but the validation results with test-time augmentation (see below) were consistently slightly worse for some reason.</p>
<p>Pre-training might allow us to scale our models up further, but because they already took a long time to train, and because the pre-training itself was time-consuming as well, we did not end up doing this for any of our final models.</p>
<p>To learn useful features with unsupervised pre-training, we relied on the max-pooling and unpooling layers to serve as a sparsification of the features. We did not try a denoising autoencoder approach for two reasons: first of all, according to the results described by Masci et al., the max- and unpooling approach produces way better filters than the denoising approach, and the further improvement of combining these approaches is negligible. Secondly, due to how the networks were implemented, it would slow things down a lot.</p>
<p>We tried different setups for this pre-training stage:</p>
<ul>
<li>greedy layerwise training vs. training the full deconvolutional stack jointly: we obtained the best results when pre-training the full stack jointly. Sometimes it was necessary to initialize this stack using the greedy approach to get it to work.</li>
<li>using tied weights vs. using untied weights: Having the weights in the deconvolutional layers be the transpose of those in the corresponding convolutional layers made the (full) autoencoder easier and much faster to train. Because of this, we never got the CAE with untied weights to reconstruct the data as well as the CAE with tied weights, despite having more trainable parameters.</li>
</ul>
<p>We also tried different approaches for the supervised finetuning stage. We observed that without some modifications to our supervised training setup, there was no difference in performance between a pre-trained network and a randomly initialized one. Possibly, by the time the randomly initialized dense layers are in a suitable parameter range, the network has already forgotten a substantial amount of the information it acquired during the pre-training phase.</p>
<p>We found two ways to overcome this:</p>
<ul>
<li>
<p>Keeping the pre-trained layers fixed for a while: before training the full networks, we only train the (randomly initialized) dense layers. This is quite fast since we only need to backpropagate through the top few layers. The idea is that we put the network more firmly in the basin of attraction the pre-training led us to.</p>
</li>
<li>
<p>Halving the learning rate in the convolutional layers: By having the dense layers adapt faster to the (pre-trained) convolutional layers, the network is less likely to make large changes to the pre-trained parameters before the dense layers are in a good parameter range.</p>
</li>
</ul>
<p>Both approaches produced similar results.</p>
<h3 id="pseudo-labeling">Pseudo-labeling</h3>
<p>Another way we exploited the information in the test set was by a combination of pseudo-labeling and knowledge distillation (<a href="http://arxiv.org/abs/1503.02531">Hinton et al.</a>). The initial results from models trained with pseudo-labeling were significantly better than we anticipated, so we ended up investigating this approach quite thoroughly.</p>
<p>Pseudo-labeling entails adding test data to the training set to create a much larger dataset. The labels of the test datapoints (the so-called pseudo-labels) are based on predictions from a previously trained model or an ensemble of models. This mostly had a regularizing effect, which allowed us to train bigger networks.</p>
<p>We experimented both with hard targets (one-hot coded) and soft targets (predicted probabilities), but quickly settled on soft targets as these gave much better results.</p>
<p>Another important detail is the balance between original data and pseudo-labeled data in the resulting dataset. In most of our experiments, 33% of each minibatch was sampled from the pseudo-labeled dataset and 67% from the real training set.</p>
<p>It is also possible to use a larger proportion of pseudo-labeled data points (e.g. 67%). In that case the model is regularized a lot more, but its predictions will also be more similar to the pseudo-labels. As mentioned before, pseudo-labeling allowed us to train bigger networks; in fact, bigger networks are necessary to make pseudo-labeling work well. When using 67% pseudo-labeled data, we even had to reduce or disable dropout, or the models would underfit.</p>
<p>Our pseudo-labeling approach differs from knowledge distillation in the sense that we use the test set instead of the training set to transfer knowledge between models. Another notable difference is that knowledge distillation is mainly intended for training smaller and faster networks that work nearly as well as bigger models, whereas we used it to train bigger models that perform <em>better</em> than the original model(s).</p>
<p>We think pseudo-labeling helped to improve our results because of the large test set and the combination of data-augmentation and test-time augmentation (see below). When pseudo-labeled test data is added to the training set, the network is optimized (or constrained) to generate predictions similar to the pseudo-labels for all possible variations and transformations of the data resulting from augmentation. This makes the network more invariant to these transformations, and forces the network to make more meaningful predictions.</p>
<p>We saw the biggest gains in the beginning (up to 0.015 improvement on the leaderboard), but even in the end we were able to improve on very large ensembles of (bagged) models (between 0.003 - 0.009).</p>
<h2 id="-model-averaging"><a name="averaging"><a> Model averaging</a></a></h2>
<p>We combined several forms of model averaging in our final submissions.</p>
<h3 id="test-time-augmentation">Test-time augmentation</h3>
<p>For each individual model, we computed predictions across various augmented versions of the input images and averaged them. This improved performance by quite a large margin: when we started doing this, our leaderboard score (log loss, so lower is better) dropped from 0.7875 to 0.7081. We used the acronym TTA to refer to this operation.</p>
<p>Initially, we used a manually created set of affine transformations which were applied to each image to augment it. This worked better than using a set of transformations with randomly sampled parameters. After a while, we looked for better ways to tile the augmentation parameter space, and settled on a <a href="http://mathworld.wolfram.com/QuasirandomSequence.html">quasi-random</a> set of 70 transformations, using slightly more modest augmentation parameter ranges than those used for training.</p>
<p>Computing model predictions for the test set using TTA could take up to 12 hours, depending on the model.</p>
<h3 id="finding-the-optimal-transformation-instead-of-averaging">Finding the optimal transformation instead of averaging</h3>
<p>Since the TTA procedure improved the score considerably, we considered the possibility of optimizing the augmentation parameters at prediction time. This is possible because affine transformations are differentiable with respect to their parameters.</p>
<p>In order to do so, we implemented affine transformations as layers in a network, so that we could backpropagate through them. After the transformation is applied to an image, a pixel can land in between two positions of the pixel grid, which makes interpolation necessary. This makes finding these derivatives quite complex.</p>
<p>We tried various approaches to find the optimal augmentation, including the following:</p>
<ul>
<li>Optimizing the transformation parameters to maximize (or minimize) the confidence of the predictions.</li>
<li>Training a convnet to predict the optimal transformation parameters for another convnet to use.</li>
</ul>
<p>Unfortunately we were not able to improve our results with any of these approaches. This may be because selecting an optimal input augmentation as opposed to averaging across augmentations removes the regularizing effect of the averaging operation. As a consequence we did not use this technique in our final submissions, but we plan to explore this idea further.</p>
<figure>
<iframe src="https://gfycat.com/ifr/BlandEasyHamadryad" frameborder="0" scrolling="no" width="100%" style="-webkit-backface-visibility: hidden;-webkit-transform: scale(1);"></iframe>
<figcaption>Animated visualization of the optimization of the affine transformation parameters.</figcaption>
</figure>
<h3 id="combining-different-models">Combining different models</h3>
<p>In total we trained over 300 models, so we had to select how many and which models to use in the final blend. For this, we used cross-validation on our validation set: on each fold, we optimized the blending weights of all models to minimize the loss of the ensemble on the training part of the fold.</p>
<p>We then regularly built candidate ensembles from varying numbers of the top-weighted models and evaluated them on the held-out part of each fold, which gave us an approximate idea of which models were worth including.</p>
<p>Once the models were selected, they were blended uniformly or with weights optimized on the validation set. Both approaches gave comparable results.</p>
<p>The models selected by this process were not necessarily the ones with the lowest TTA score. Some models with relatively poor scores were selected because they made very different predictions from our other models. A few models had poor scores due to overfitting, but were selected nevertheless because averaging reduces the effect of overfitting.</p>
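<p>A minimal sketch of blend-weight optimization is given below. The softmax parameterization (to keep the weights positive and summing to one) and the Nelder-Mead optimizer are assumptions for this sketch; our actual optimizer may have differed.</p>

```python
import numpy as np
from scipy.optimize import minimize

def log_loss(probs, labels, eps=1e-15):
    """Multiclass log loss of predicted probabilities against integer labels."""
    p = np.clip(probs[np.arange(len(labels)), labels], eps, 1.0)
    return -np.mean(np.log(p))

def optimize_blend_weights(model_probs, labels):
    """Find blending weights minimizing the ensemble's log loss.

    model_probs: list of (n_samples, n_classes) prediction arrays,
    one per model, on the same validation samples.
    """
    n = len(model_probs)
    stacked = np.stack(model_probs)  # (n_models, n_samples, n_classes)

    def objective(w_raw):
        w = np.exp(w_raw)  # softmax parameterization keeps weights positive
        w /= w.sum()
        blend = np.tensordot(w, stacked, axes=1)
        return log_loss(blend, labels)

    res = minimize(objective, np.zeros(n), method="Nelder-Mead")
    w = np.exp(res.x)
    return w / w.sum()
```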
<h3 id="bagging">Bagging</h3>
<p>To improve the score of the ensemble further, we replaced some of the models by an average of 5 models (including the original one), where each model was trained on a different subset of the data.</p>
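<p>Schematically, the bagging step looks like this. The subset size and the <code>train_fn</code>/<code>predict_fn</code> placeholders are assumptions for the sketch:</p>

```python
import numpy as np

def bagged_predictions(train_fn, predict_fn, X, y, X_test, n_bags=5, seed=0):
    """Train n_bags models on random subsets of the data and average
    their test-set predictions. train_fn and predict_fn stand in for
    the actual model training and prediction code.
    """
    rng = np.random.default_rng(seed)
    preds = []
    for _ in range(n_bags):
        idx = rng.choice(len(X), size=int(0.9 * len(X)), replace=False)
        model = train_fn(X[idx], y[idx])
        preds.append(predict_fn(model, X_test))
    return np.mean(preds, axis=0)
```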
<h2 id="-miscellany"><a name="miscellany"></a> Miscellany</h2>
<p>Here are a few other things we tried, with varying levels of success:</p>
<ul>
<li>untied biases: having separate biases for each spatial location in the convolutional layer seemed to improve results very slightly.</li>
<li>winner take all nonlinearity (WTA, also known as channel-out) in the fully connected layers instead of ReLUs / maxout.</li>
<li>smooth nonlinearities: to increase the amount of variance in our blends we tried replacing the leaky rectified linear units with a smoothed version. Unfortunately this worsened our public leaderboard score.</li>
<li>specialist models: we tried training special models for highly confusable classes of chaetognaths, some protists, etc. using the knowledge distillation approach described by <a href="http://arxiv.org/abs/1503.02531">Hinton et al.</a>. We also tried a self-informed neural network structure learning (<a href="http://arxiv.org/abs/1412.6563">Warde-Farley et al.</a>), but in both cases the improvements were negligible.</li>
<li>batch normalization: unfortunately we were unable to reproduce the spectacular improvements in convergence speed described by <a href="http://arxiv.org/abs/1502.03167">Ioffe and Szegedy</a> for our models.</li>
<li>Using FaMe regularization as described by <a href="http://arxiv.org/abs/1412.6630">Rudy et al.</a> instead of dropout increased overfitting a lot. The regularizing effect seemed to be considerably weaker.</li>
<li>Semi-supervised learning with soft and hard bootstrapping as described by <a href="http://arxiv.org/abs/1412.6596">Reed et al.</a> did not improve performance or reduce overfitting.</li>
</ul>
<p>Here’s a non-exhaustive list of things that we found to reduce overfitting (including the obvious ones):</p>
<ul>
<li>dropout (various forms)</li>
<li>aggressive data augmentation</li>
<li>suitable model architectures (depth and width of the layers influence overfitting in complicated ways)</li>
<li>weight decay</li>
<li>unsupervised pre-training</li>
<li>cyclic pooling (especially with root-mean-square pooling)</li>
<li>leaky ReLUs</li>
<li>pseudo-labeling</li>
</ul>
<p>We also monitored the classification accuracy of our models during the competition. Our best models achieved an accuracy of over 82% on the validation set, and a top-5 accuracy of over 98%. This makes it possible to use the model as a tool for speeding up manual annotation.</p>
<h2 id="-conclusion"><a name="conclusion"></a> Conclusion</h2>
<p>We had a lot of fun working on this problem together and learned a lot! If this problem interests you, be sure to check out <a href="https://www.kaggle.com/c/datasciencebowl/forums">the competition forum</a>. Many of the participants will be posting overviews of their approaches in the coming days.</p>
<p>Congratulations to the other winners, and our thanks to the competition organizers and sponsors. We would also like to thank our supervisor Joni Dambre for letting us work on this problem together.</p>
<p>We will clean up our code and put it on GitHub soon. If you have any questions or feedback about this post, feel free to leave a comment.</p>
<p><em>One of our team, Ira Korshunova, is currently looking for a good research lab to start her PhD next semester. She can be contacted at <a href="mailto:irene.korshunova@gmail.com">irene.korshunova@gmail.com</a>.</em></p>
<p><strong>UPDATE</strong> (March 25th): the code is now available on GitHub: <a href="https://github.com/benanne/kaggle-ndsb">https://github.com/benanne/kaggle-ndsb</a></p>
<h1 id="theano-metaopt"><a href="https://benanne.github.io/2014/12/09/theano-metaopt">The fastest convolutions in Theano with meta-optimization</a> (2014-12-09)</h1>
<p style="background-color: #ffa; padding: 1.2em;">
<em>Guest post:</em> <a href="http://ofai.at/~jan.schlueter">Jan Schlüter from the OFAI</a>, a fellow MIR researcher I have met at several conferences, recently added a feature to Theano that fits so well with my <a href="//benanne.github.io/2014/04/03/faster-convolutions-in-theano.html">previous</a> <a href="//benanne.github.io/2014/05/12/fft-convolutions-in-theano.html">two</a> posts on fast convolutions that we decided to include his writeup on my blog. So enjoy the third part of the series, written by Jan!
</p>
<p>Over the past year, <a href="http://github.com/Theano/Theano">Theano</a> has accumulated several alternative implementations for 2D convolution, the most costly operation in Convolutional Neural Networks.
There is no single implementation that is the fastest for all possible image and kernel shapes,
but with Theano you can mix and match them at will.
Now mixing and matching is something that can be easily automated: Meet meta-optimization!</p>
<p>The idea is to automatically select the fastest available implementation for each individual convolution operation in a Theano function, simply by timing them.
The feature is already available in Theano: If you install the latest version from github, you can activate it by setting the environment variable <code class="language-plaintext highlighter-rouge">THEANO_FLAGS=optimizer_including=conv_meta,metaopt.verbose=1</code>.</p>
<p>In the following, I will explain what it does, how it works, and demonstrate that it can outperform all existing convnet libraries.</p>
<h2 id="batched-convolution">Batched convolution</h2>
<p>Before we begin, note that the convolution operation in Convolutional Neural Networks (CNNs) as used for Computer Vision is not just a convolution of a single 2D input image with a single 2D filter kernel.
For one, the input image can have multiple channels, such as a color image composed of three values per pixel. It can thus be expressed as a 3D tensor. To match this, the filter kernel has as many values per pixel as the input image, which makes it a 3D tensor as well. When computing the output, each channel is convolved separately with its corresponding kernel, and the resulting images are added up to produce a single 2D output image.
But usually, each convolutional layer returns a multi-channel output (a 3D tensor), which is achieved by learning multiple sets of kernels (a 4D tensor).
Finally, images are often propagated through the network in mini-batches of maybe 64 or 256 items to be processed independently, so the input and output become 4D tensors.</p>
<p>Putting everything together, the batched convolution operation convolves a <abbr title="batch size, input channels, input rows, input columns">4D input tensor</abbr> with a <abbr title="output channels, input channels, kernel rows, kernel columns">4D kernel tensor</abbr> to produce a <abbr title="batch size, output channels, output rows, output columns">4D output tensor</abbr>. Obviously, this offers ample opportunity for parallelization. Add to this the different possible ways of computing a 2D convolution, and you can see why there are so many competing implementations.</p>
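<p>As a reference for the tensor shapes involved, here is a deliberately naive numpy implementation of the batched "valid" convolution just described (all fast implementations compute the same result far more efficiently):</p>

```python
import numpy as np

def batched_conv_valid(inputs, kernels):
    """Naive batched 'valid' convolution, with kernel flipping as in the
    forward pass of a convolutional layer.

    inputs:  (batch, in_channels, in_rows, in_cols)
    kernels: (out_channels, in_channels, k_rows, k_cols)
    returns: (batch, out_channels, in_rows - k_rows + 1, in_cols - k_cols + 1)
    """
    b, c_in, h, w = inputs.shape
    c_out, _, kh, kw = kernels.shape
    out = np.zeros((b, c_out, h - kh + 1, w - kw + 1))
    flipped = kernels[:, :, ::-1, ::-1]  # convolution flips the kernel
    for i in range(out.shape[2]):
        for j in range(out.shape[3]):
            patch = inputs[:, :, i:i + kh, j:j + kw]
            # sum over input channels and the kernel window
            out[:, :, i, j] = np.tensordot(patch, flipped,
                                           axes=([1, 2, 3], [1, 2, 3]))
    return out
```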
<h2 id="the-repertoire">The repertoire</h2>
<p>As an actively maintained open-source project with several external contributors, Theano has grown to have access to five convolution implementations:</p>
<ul>
<li>a “legacy” implementation that has been created for Theano</li>
<li>Alex Krizhevsky’s <strong><a href="http://code.google.com/p/cuda-convnet">cuda-convnet</a></strong>, via a wrapper already <a href="//benanne.github.io/2014/04/03/faster-convolutions-in-theano.html">described by Sander</a></li>
<li>an <strong>FFT-based convolution</strong> <a href="//benanne.github.io/2014/05/12/fft-convolutions-in-theano.html">started by Sander</a> and <a href="https://github.com/Theano/Theano/pull/1870">finished by Arnaud Bergeron</a></li>
<li>the <strong>gemm-based convolution</strong> from <a href="http://caffe.berkeleyvision.org">Caffe</a>, <a href="https://github.com/Theano/Theano/pull/2002">started by Arjun Jain and Frédéric Bastién</a> and <a href="https://github.com/Theano/Theano/pull/2033">finished by me</a></li>
<li>Nvidia’s new <strong><a href="https://developer.nvidia.com/cuDNN">cuDNN</a> library</strong>, via a wrapper done by <a href="https://github.com/Theano/Theano/pull/2096">Arnaud</a> and subsequently improved by <a href="https://github.com/Theano/Theano/issues?q=dnn+is%3Aclosed+author%3Anouiz">Frédéric</a> and <a href="https://github.com/Theano/Theano/issues?q=dnn+is%3Aclosed+author%3Af0k">me</a></li>
</ul>
<p>All of these have their strengths and weaknesses.
cuda-convnet only supports square kernels and places several restrictions on the number of input and output channels and the batch size.
The FFT-based convolution is applicable to any configuration, but requires a lot of extra memory that practically limits it to small batch and image sizes or very potent graphics cards.
cuDNN requires a GPU of Compute Capability 3.0 or above,
and the convolution ported from Caffe needs some extra memory again.
Finally, the legacy implementation comes free of limitations, but is usually the slowest of the pack.</p>
<p>Depending on the configuration – that is, the batch size, image shape, filter count and kernel shape –, any of these five implementations can be the fastest.</p>
<h2 id="three-convolutions-per-layer">Three convolutions per layer</h2>
<p>To complicate matters, each convolutional layer in a convnet actually results in three batched convolution operations to be performed in training:</p>
<ol>
<li>The <strong>forward pass</strong>, a valid convolution of images and kernels</li>
<li>The <strong>gradient wrt. weights</strong>, a valid convolution of images and the gradient wrt. output</li>
<li>The <strong>gradient wrt. input</strong>, a full convolution of the kernels and the gradient wrt. output</li>
</ol>
<p>For a valid convolution, the kernel is applied wherever it completely overlaps with the input (i.e., it only touches valid data).
For a full convolution, it is applied wherever it overlaps with the input by at least one pixel –
this is equivalent to padding the input with a suitably-sized symmetric border of zeros and applying a valid convolution.</p>
<p>(For the eager ones: The third one in the list above is actually a correlation, because the kernels are not flipped as in the forward pass. And the second one requires the batch size and channels of the input, kernel and output tensors to be swapped. Still, all of these can be expressed using the batched convolution operation described in the beginning.)</p>
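<p>The equivalence between a full convolution and a zero-padded valid convolution can be checked directly in 2D; the example below uses <code>scipy.signal.convolve2d</code> purely as an illustration:</p>

```python
import numpy as np
from scipy.signal import convolve2d

x = np.arange(9.0).reshape(3, 3)
k = np.ones((2, 2))

full = convolve2d(x, k, mode="full")
# a full convolution equals zero-padding the input with a border of
# (k_rows - 1, k_cols - 1) on each side, then applying a valid convolution
padded = np.pad(x, 1)
assert np.allclose(convolve2d(padded, k, mode="valid"), full)
```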
<p>The “big libraries” (cuda-convnet, Caffe and cuDNN) each come with three algorithms specialized for these three cases, while the FFT-based convolution just distinguishes between valid and full convolutions.</p>
<h2 id="cherry-picking">Cherry-picking</h2>
<p>A lot of my work on Theano’s convolution was triggered by following Soumith Chintala’s <a href="https://github.com/soumith/convnet-benchmarks">convnet-benchmarks</a> initiative, which set out to compare all freely available Convolutional Neural Network libraries in terms of their performance.
When looking at <a href="https://github.com/soumith/convnet-benchmarks/blob/88d4f3b41d86782a8fa1e098c9789c4674bbddb3/README.md">some of the first results posted</a>, the first thing I noticed was that it would pay off to use a different library for each of the five configurations tested. This has quickly been included as a hypothetical “cherry-picking” row into the result tables.</p>
<p>I took over maintenance of Soumith’s Theano benchmark script and evolved it into a handy little tool to compare its convolution implementations for different configurations. Feel free to <a href="https://github.com/soumith/convnet-benchmarks/tree/master/theano">download the script</a> and follow along.</p>
<p>So let’s see what we could gain with cherry-picking in Theano:</p>
<div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code>$ SKIP=meta python pylearn2_benchmark.py i3x64x64,k128x7x7,b64
Using gpu device 0: GeForce GTX 780 Ti
CONFIG: input = 3 x 64 x 64 * ker = 3 x 128 x 7 x 7 ( bs = 64 , stride = 1 )
theano.tensor.nnet.conv.conv2d ==> fprop ==> 43
theano.tensor.nnet.conv.conv2d ==> bprop inputs ==> 44
theano.tensor.nnet.conv.conv2d ==> bprop weights ==> 185
theano.sandbox.cuda.fftconv.conv2d_fft ==> fprop ==> 19
theano.sandbox.cuda.fftconv.conv2d_fft ==> bprop inputs ==> 26
theano.sandbox.cuda.fftconv.conv2d_fft ==> bprop weights ==> 20
(auto) theano.sandbox.cuda.dnn.GpuDnnConv ==> fprop ==> 4
(auto) theano.sandbox.cuda.dnn.GpuDnnConv ==> bprop inputs ==> 7
(auto) theano.sandbox.cuda.dnn.GpuDnnConv ==> bprop weights ==> 6
(auto) theano.sandbox.cuda.blas.GpuCorrMM ==> fprop ==> 6
(auto) theano.sandbox.cuda.blas.GpuCorrMM ==> bprop inputs ==> 7
(auto) theano.sandbox.cuda.blas.GpuCorrMM ==> bprop weights ==> 10
pylearn2.sandbox.cuda_convnet(partial_sum=None) ==> fprop ==> 7
pylearn2.sandbox.cuda_convnet(partial_sum=None) ==> bprop inputs ==> 11
pylearn2.sandbox.cuda_convnet(partial_sum=None) ==> bprop weights ==> 47
pylearn2.sandbox.cuda_convnet(partial_sum=1) ==> fprop ==> 7
pylearn2.sandbox.cuda_convnet(partial_sum=1) ==> bprop inputs ==> 11
pylearn2.sandbox.cuda_convnet(partial_sum=1) ==> bprop weights ==> 13
</code></pre></div></div>
<p>What we see here are the respective computation times in milliseconds for a particular configuration (tensor shapes) for the legacy implementation, FFT-based convolution, cuDNN, gemm-based convolution and cuda-convnet (with two different values for a tuning parameter).
For this layer, cuDNN would be the optimal choice.</p>
<p>Let’s try a second configuration:</p>
<div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code>$ SKIP=meta python pylearn2_benchmark.py i32x15x80,k64x5x5,b256
Using gpu device 0: GeForce GTX 780 Ti
CONFIG: input = 32 x 15 x 80 * ker = 32 x 64 x 5 x 5 ( bs = 256 , stride = 1 )
theano.tensor.nnet.conv.conv2d ==> fprop ==> 146
theano.tensor.nnet.conv.conv2d ==> bprop inputs ==> 182
theano.tensor.nnet.conv.conv2d ==> bprop weights ==> 162
theano.sandbox.cuda.fftconv.conv2d_fft ==> fprop ==> 20
theano.sandbox.cuda.fftconv.conv2d_fft ==> bprop inputs ==> 24
theano.sandbox.cuda.fftconv.conv2d_fft ==> bprop weights ==> 15
(auto) theano.sandbox.cuda.dnn.GpuDnnConv ==> fprop ==> 18
(auto) theano.sandbox.cuda.dnn.GpuDnnConv ==> bprop inputs ==> 23
(auto) theano.sandbox.cuda.dnn.GpuDnnConv ==> bprop weights ==> 25
(auto) theano.sandbox.cuda.blas.GpuCorrMM ==> fprop ==> 22
(auto) theano.sandbox.cuda.blas.GpuCorrMM ==> bprop inputs ==> 29
(auto) theano.sandbox.cuda.blas.GpuCorrMM ==> bprop weights ==> 30
pylearn2.sandbox.cuda_convnet(partial_sum=None) ==> fprop ==> 16
pylearn2.sandbox.cuda_convnet(partial_sum=None) ==> bprop inputs ==> 20
pylearn2.sandbox.cuda_convnet(partial_sum=None) ==> bprop weights ==> 40
pylearn2.sandbox.cuda_convnet(partial_sum=1) ==> fprop ==> 16
pylearn2.sandbox.cuda_convnet(partial_sum=1) ==> bprop inputs ==> 21
pylearn2.sandbox.cuda_convnet(partial_sum=1) ==> bprop weights ==> 28
</code></pre></div></div>
<p>This time, the FFT-based convolution is faster, but the truly optimal choice would be combining it with cuda-convnet.</p>
<p>We see that the meta-optimizer should not just cherry-pick a different implementation per convolutional layer, but even a different implementation for each of the three convolutions in a layer – something that was not possible in Theano before (nor in any other library I am aware of).</p>
<h2 id="the-swapping-trick">The “swapping trick”</h2>
<p>As you recall, cuda-convnet, Caffe and cuDNN come with specialized algorithms for the three convolutions per layer.
Interestingly, when porting the gemm-based convolution from Caffe to Theano, I noticed that the effort I put into properly using its two backward pass algorithms when applicable did not always pay off: For some configurations, it was faster to just use the forward pass algorithm instead, transposing tensors as needed.
I thus added <a href="https://github.com/Theano/Theano/blob/1477ded8740636c381076b8720055d6c2be64590/theano/sandbox/cuda/opt.py#L1372-1400">a shape-based heuristic</a> to select the fastest algorithm for the gemm-based convolution (making Theano’s port faster than Caffe for some configurations).</p>
<p>When adding support for Nvidia’s cuDNN library, Arnaud understandably assumed that it would hide this complexity from the user and select the optimal algorithm internally. So at first, Theano did not tell cuDNN whether a particular convolution’s purpose was a forward pass or one of the backward passes. When I <a href="https://github.com/Theano/Theano/pull/2273">changed the implementation</a> accordingly, I again noticed that while performance generally improved a lot, for some configurations, using the “wrong” algorithm was actually faster.</p>
<p>Just as for Caffe, we can use this knowledge to be faster than cuDNN.
As the implementation is unknown, we cannot easily define a heuristic for choosing between the cuDNN algorithms.
However, the meta-optimizer can just try all applicable algorithms and see which one is the fastest.
I found it to suffice to just try two algorithms per convolution:</p>
<ul>
<li>For the forward pass, try the “correct” algorithm and the gradient wrt. weights (both are valid convolutions)</li>
<li>For the gradient wrt. weights, try the “correct” algorithm and the forward pass</li>
<li>For the gradient wrt. inputs, try the “correct” algorithm and the forward pass (with additional zero padding to make it a full convolution)</li>
</ul>
<p>I call this the “swapping trick” because it often leads to the first two algorithms being swapped.</p>
<h2 id="implementation">Implementation</h2>
<p>To understand why Theano was a perfect fit to add automatic algorithm selection, we will need to explain a bit of its inner workings.</p>
<p>First, Theano is not a neural network library, but a mathematical expression compiler.
In contrast to, say, Caffe, its basic components are not neural network layers, but mathematical operations.
Implementing a neural network is done by composing the expression for the forward pass (which will probably include matrix multiplications, vector additions, elementwise nonlinearities and possibly batched convolution and pooling), using this to build an expression for the training cost, and then letting Theano transform it into expressions for the gradients wrt. the parameters to be learned.
Finally, the expressions are compiled into functions that evaluate them for specific settings of the free variables (such as a mini-batch of training data).</p>
<p>But right before an expression is compiled, it is <em>optimized</em>, and this is where all the magic happens.
The expression is represented as a graph of Apply nodes (operations) and Variable nodes (the inputs and outputs of an operation), and Theano comes with a bunch of <em>graph optimizers</em> that modify the graph to produce the same result either more efficiently or in a more numerically stable way.
<br />
One particular graph optimizer moves convolution operations from the CPU to the GPU by replacing the respective Apply node and adding the necessary transfer operations around it.
A whole set of graph optimizers then replaces the legacy GPU convolution operation with one of the more efficient implementations available in Theano. These optimizers have relative priorities and can be enabled and disabled by the user.</p>
<p>The new meta-optimizer is just another graph optimizer with a twist: When it encounters a convolution operation, it applies each of the set of available graph optimizers (plus the cuDNN “swapping trick” optimizer) in sequence, each time compiling and executing the subgraph performing the convolution, and chooses the one resulting in the best performance.
(Finally, this explains why it’s called <em>meta</em>-optimization.)
<br />
As the basic components in Theano are the mathematical operations, there is no extra work needed to be able to choose different implementations for the three convolutions per layer: All Theano sees when optimizing and compiling an expression is a graph containing several anonymous convolution operations, so it will naturally optimize each of them separately.</p>
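<p>Stripped of all Theano machinery, the core idea of the meta-optimizer boils down to timing every applicable candidate on the actual shapes and keeping the fastest. The sketch below is a generic illustration of that idea, not Theano's actual code:</p>

```python
import timeit

def pick_fastest(candidates, make_args, repeats=3):
    """Generic meta-optimization: time each candidate implementation on
    identical inputs and return the name of the fastest one.

    candidates: dict mapping a name to a callable.
    make_args: builds the shared inputs once, so every candidate is
    benchmarked on the same data.
    """
    args = make_args()
    timings = {}
    for name, fn in candidates.items():
        try:
            timings[name] = min(
                timeit.repeat(lambda: fn(*args), number=1, repeat=repeats))
        except Exception:
            # e.g. "not applicable" for this configuration
            continue
    best = min(timings, key=timings.get)
    return best, timings
```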
<h2 id="practical-gains">Practical gains</h2>
<p>Let us now put the meta-optimizer to test using the benchmark script mentioned in the cherry-picking section:</p>
<div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code>$ THEANO_FLAGS=metaopt.verbose=1 SKIP=legacy,gemm,fft,convnet,dnn python pylearn2_benchmark.py i128x36x12,k64x6x3,b256
Using gpu device 0: GeForce GTX 780 Ti
CONFIG: input = 128 x 36 x 12 * ker = 128 x 64 x 6 x 3 ( bs = 256 , stride = 1 )
ConvMetaOptimizer meta-optimizing GpuConv{valid, (1, 1), None, (3, 6), True, (128, 12, 36), (3, 6)}(GpuFromHost.0, GpuFromHost.0) (5 choices):
* local_conv_fft_full: not applicable
* local_conv_fft_valid: 0.012958 sec
* local_conv_dnn: 0.021169 sec
* local_conv_gemm: 0.03973 sec
* local_conv_dnn_alternative: 0.044379 sec
= local_conv_fft_valid
(experimental) meta-optimizer ==> fprop ==> 12
ConvMetaOptimizer meta-optimizing GpuConv{full, (1, 1), None, (3, 6), True, (64, 10, 31), (3, 6)}(GpuFromHost.0, GpuFromHost.0) (5 choices):
* local_conv_fft_full: 0.019099 sec
* local_conv_fft_valid: not applicable
* local_conv_dnn: 0.032979 sec
* local_conv_gemm: 0.028478 sec
* local_conv_dnn_alternative: 0.015099 sec
= local_conv_dnn_alternative
(experimental) meta-optimizer ==> bprop inputs ==> 15
ConvMetaOptimizer meta-optimizing GpuConv{valid, (1, 1), None, (10, 31), False, (256, 12, 36), (10, 31)}(GpuFromHost.0, GpuFromHost.0) (5 choices):
* local_conv_fft_full: not applicable
* local_conv_fft_valid: 0.011441 sec
* local_conv_dnn: 0.030338 sec
* local_conv_gemm: 0.025984 sec
* local_conv_dnn_alternative: 0.031552 sec
= local_conv_fft_valid
(experimental) meta-optimizer ==> bprop weights ==> 12
</code></pre></div></div>
<p>In verbose mode, the meta-optimizer reports which implementations are tested, how each of them performs and which one is finally chosen.
For the configuration at hand, it turns out that the FFT-based implementation is fastest for the forward pass and the gradient wrt. weights, and cuDNN is fastest for the gradient wrt. inputs – but only when using the “wrong” algorithm for it (namely, cuDNN’s forward pass algorithm with zero padding, tried according to the swapping trick).
In all three instances, the optimal algorithm is about twice as fast as just choosing cuDNN, which would have been Theano’s current default behavior.</p>
<p>When training a full network, the impact will generally be smaller, because the convolution operations only constitute a part of the expressions evaluated (but often the most costly part).
The improvement also heavily depends on the input and kernel shapes – for a wide range of configurations, just using cuDNN for all convolutions is nearly optimal.
Still, a colleague of Sander reported a threefold performance improvement for a network trained for a Kaggle competition, with the meta-optimizer combining FFT, Caffe, and cuDNN with and without the swapping trick.
<!-- For the curious: the configurations were `i1x104x104,k128x9x9,b8` and `i128x96x96,k16x1x1,b8`. --></p>
<p>To get an estimate on how much Theano could help for your use case, just run <a href="https://github.com/soumith/convnet-benchmarks/tree/master/theano">the benchmark script</a> for the configurations occurring in a forward pass through your network.
If you already use Theano, just set <code class="language-plaintext highlighter-rouge">THEANO_FLAGS=optimizer_including=conv_meta</code> to rest assured that you will always make the most of the time (and electricity!) you spend on training your networks.</p>
<h2 id="future">Future</h2>
<p>While the basic machinery is in place and works fine, there are a lot of conceivable improvements:</p>
<ul>
<li>The meta-optimizer should cache its results on disk to speed up repeated compilations of the same graph.</li>
<li>Right now, the meta-optimizer uses all available convolution operations in Theano; it should be possible to control this.</li>
<li>As cuda-convnet is not included in Theano, but an external project (Pylearn2), it is not included in the meta-optimizer. However, it is possible to register additional optimizers at runtime via <code class="language-plaintext highlighter-rouge">theano.sandbox.cuda.opt.conv_metaopt.register()</code>. It would be nice to write such a pluggable optimizer for cuda-convnet.</li>
<li>Similarly, it would be nice to have a wrapper for cuda-convnet2 (in a separate repository) along with an optimizer to be registered with the meta-optimizer.</li>
<li>Currently, meta-optimization can only be used for non-strided valid or full convolutions, because this is what the legacy implementation is limited to. Changing this would require <a href="https://github.com/Theano/Theano/issues/2268#issuecomment-63621626">some refactoring</a>, but lead to cleaner code and slightly improved performance.</li>
<li>Finally, it could be worthwhile to repeat the same for the pooling operation of CNNs: Port additional implementations to Theano, benchmark them and add a meta-optimizer.</li>
</ul>
<p>Watch <a href="https://github.com/Theano/Theano/issues/2072">Issue #2072</a> on github for any progress on this, or even better, step in and implement one of these features if you can use it!
Both that issue and <a href="https://groups.google.com/forum/#!forum/theano-dev">theano-dev</a> are well-suited to ask for hints about implementing any of these TODOs – we’d be glad to have you on board.</p>
<h1 id="spotify-cnns"><a href="https://benanne.github.io/2014/08/05/spotify-cnns">Recommending music on Spotify with deep learning</a> (2014-08-05)</h1>
<p>This summer, I’m interning at <a href="https://www.spotify.com">Spotify</a> in New York City, where I’m working on content-based music recommendation using convolutional neural networks. In this post, I’ll explain my approach and show some preliminary results.</p>
<h2 id="overview">Overview</h2>
<p>This is going to be a long post, so here’s an overview of the different sections. If you want to skip ahead, just click the section title to go there.</p>
<ul>
<li><em><a href="#collaborativefiltering">Collaborative filtering</a></em><br />A very brief introduction, its virtues and its flaws.</li>
<li><em><a href="#contentbased">Content-based recommendation</a></em><br />What to do when no usage data is available.</li>
<li><em><a href="#predicting">Predicting listening preferences with deep learning</a></em><br />Music recommendation based on audio signals.</li>
<li><em><a href="#scalingup">Scaling up</a></em><br />Some details about the convnets I’ve been training at Spotify.</li>
<li><em><a href="#analysis">Analysis: what is it learning?</a></em><br />A look at what the convnets learn about music, with <strong>lots of audio examples</strong>, yay!</li>
<li><em><a href="#whatwill">What will this be used for?</a></em><br />Some potential applications of my work.</li>
<li><em><a href="#futurework">Future work</a></em></li>
<li><em><a href="#conclusion">Conclusion</a></em></li>
</ul>
<div style="float: right;"><a href="https://www.spotify.com/"><img src="/images/spotifylogo.jpg" alt="Spotify" /></a></div>
<h2 id="collaborative-filtering"><a name="collaborativefiltering"></a>Collaborative filtering</h2>
<p>Traditionally, Spotify <a href="http://www.slideshare.net/MrChrisJohnson/algorithmic-music-recommendations-at-spotify">has relied mostly on collaborative filtering approaches</a> to power their recommendations. The idea of <a href="http://en.wikipedia.org/wiki/Collaborative_filtering">collaborative filtering</a> is to <strong>determine the users’ preferences from historical usage data</strong>. For example, if two users listen to largely the same set of songs, their tastes are probably similar. Conversely, if two songs are listened to by the same group of users, they probably sound similar. This kind of information can be exploited to make recommendations.</p>
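<p>To make the idea concrete, here is a toy item-item flavor of collaborative filtering: songs whose columns in a play-count matrix point in similar directions get a high cosine similarity. The matrix values are made up for illustration; production systems use far more sophisticated (e.g. latent factor) models.</p>

```python
import numpy as np

# rows: users, columns: songs (hypothetical play counts)
plays = np.array([
    [5, 4, 0, 0],
    [3, 5, 0, 1],
    [0, 0, 4, 5],
    [0, 1, 5, 4],
], dtype=float)

# cosine similarity between song columns
norms = np.linalg.norm(plays, axis=0)
sim = (plays.T @ plays) / np.outer(norms, norms)

# Songs 0 and 1 share listeners, as do songs 2 and 3, so those pairs
# come out much more similar than cross pairs.
print(np.round(sim, 2))
```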
<p>Pure collaborative filtering approaches do not use any kind of information about the items that are being recommended, except for the consumption patterns associated with them: they are <strong>content-agnostic</strong>. This makes these approaches widely applicable: the same type of model can be used to recommend books, movies or music, for example.</p>
<p>Unfortunately, this also turns out to be their biggest flaw. Because of their reliance on usage data, popular items will be much easier to recommend than unpopular items, as there is more usage data available for them. This is usually the opposite of what we want. For the same reason, the recommendations can often be rather boring and predictable.</p>
<p>Another issue, more specific to music, is the <strong>heterogeneity of content with similar usage patterns</strong>. For example, users may listen to entire albums in one go, but albums may contain intro tracks, outro tracks, interludes, cover songs and remixes. These items are atypical for the artist in question, so they aren’t good recommendations. Collaborative filtering algorithms will not pick up on this.</p>
<p>But perhaps the biggest problem is that <strong>new and unpopular songs cannot be recommended</strong>: if there is no usage data to analyze, the collaborative filtering approach breaks down. This is the so-called <strong>cold-start problem</strong>. We want to be able to recommend new music right when it is released, and we want to tell listeners about awesome bands they have never heard of. To achieve these goals, we will need to use a different approach.</p>
<h2 id="content-based-recommendation"><a name="contentbased"></a>Content-based recommendation</h2>
<p>Recently, Spotify has shown considerable interest in incorporating other sources of information into their recommendation pipeline to mitigate some of these problems, as evidenced by their acquisition of music intelligence platform company <a href="http://the.echonest.com/">The Echo Nest</a> a few months back. There are many different kinds of information associated with music that could aid recommendation: tags, artist and album information, lyrics, text mined from the web (reviews, interviews, …), and the audio signal itself.</p>
<p>Of all these information sources, the audio signal is probably the most difficult to use effectively. There is quite a large <strong>semantic gap</strong> between music audio on the one hand, and the various aspects of music that affect listener preferences on the other hand. Some of these are fairly easy to extract from audio signals, such as the genre of the music and the instruments used. Others are a little more challenging, such as the mood of the music, and the year (or time period) of release. A couple are practically impossible to obtain from audio: the geographical location of the artist and lyrical themes, for example.</p>
<p>Despite all these challenges, it is clear that the actual <em>sound</em> of a song will play a very big role in determining whether or not you enjoy listening to it - so it seems like a good idea to try to predict who will enjoy a song by analyzing the audio signal.</p>
<h2 id="predicting-listening-preferences-with-deep-learning"><a name="predicting"></a>Predicting listening preferences with deep learning</h2>
<p>In December last year, my colleague Aäron van den Oord and I published a paper on this topic at NIPS, titled <strong>‘<a href="https://papers.nips.cc/paper/5004-deep-content-based-music-recommendation">Deep content-based music recommendation</a>‘</strong>. We tried to tackle the problem of predicting listening preferences from audio signals by training a regression model to predict the <strong>latent representations</strong> of songs that were obtained from a collaborative filtering model. This way, we could predict the representation of a song in the collaborative filtering space, even if no usage data was available. (As you can probably infer from the title of the paper, the regression model in question was a deep neural network.)</p>
<p>The underlying idea of this approach is that many collaborative filtering models work by projecting both the listeners and the songs into a shared low-dimensional <strong>latent space</strong>. The position of a song in this space encodes all kinds of information that affects listening preferences. If two songs are close together in this space, they are probably similar. If a song is close to a user, it is probably a good recommendation for that user (provided that they haven’t heard it yet). If we can predict the position of a song in this space from audio, we can recommend it to the right audience without having to rely on historical usage data.</p>
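To make this concrete, here is a minimal NumPy sketch of recommendation in a shared latent space. All names and shapes are illustrative: a real collaborative filtering model learns these vectors from usage data, whereas here they are just random placeholders.

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical 40-dimensional latent factors for a handful of users and
# songs, standing in for the output of a collaborative filtering model.
user_vectors = rng.normal(size=(5, 40))
song_vectors = rng.normal(size=(100, 40))

def recommend(user_idx, k=10):
    """Rank songs for a user by how close they are to the user's position
    in the latent space (here measured with the dot product)."""
    scores = song_vectors @ user_vectors[user_idx]
    return np.argsort(scores)[::-1][:k]  # indices of the top-k songs
```

If we can predict `song_vectors` rows from audio instead of usage data, the same ranking machinery works for brand-new songs.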
<p>We visualized this in the paper by projecting the predictions of our model in the latent space down to two dimensions using the <a href="http://homepage.tudelft.nl/19j49/t-SNE.html">t-SNE algorithm</a>. As you can see below on the resulting map, similar songs cluster together. Rap music can be found mostly in the top left corner, whereas electronic artists congregate at the bottom of the map.</p>
<figure>
<a href="/images/prentje_nips.png"><img src="/images/prentje_nips.png" alt="t-SNE visualization of user listening patterns predicted from audio." /></a>
<figcaption>t-SNE visualization of the latent space (middle). A few close-ups show artists whose songs are projected in specific areas. Taken from <i><a href="https://papers.nips.cc/paper/5004-deep-content-based-music-recommendation">Deep content-based music recommendation</a>, Aäron van den Oord, Sander Dieleman and Benjamin Schrauwen, NIPS 2013.</i></figcaption>
</figure>
<h2 id="scaling-up"><a name="scalingup"></a>Scaling up</h2>
<p>The deep neural network that we trained for the paper consisted of two convolutional layers and two fully connected layers. The input consisted of spectrograms of 3-second fragments of audio. To get a prediction for a longer clip, we just split it up into 3-second windows and averaged the predictions across these windows.</p>
<p>At Spotify, I have access to a larger dataset of songs, and a bunch of different latent factor representations obtained from various collaborative filtering models. They also got me a nice GPU to run my experiments on. This has allowed me to scale things up quite a bit. I am currently training convolutional neural networks (convnets) with 7 or 8 layers in total, using much larger intermediate representations and many more parameters.</p>
<h3 id="architecture">Architecture</h3>
<p>Below is an example of an architecture that I’ve tried out, which I will describe in more detail. It has four convolutional layers and three dense layers. As you will see, there are some important differences between convnets designed for audio signals and their more traditional counterparts used for computer vision tasks.</p>
<p><strong>Warning: gory details ahead! Feel free to skip ahead to ‘Analysis’ if you don’t care about things like ReLUs, max-pooling and minibatch gradient descent.</strong></p>
<figure>
<a href="/images/spotify_convnet.png"><img src="/images/spotify_convnet.png" alt="One of the convolutional neural network architectures I've tried out." /></a>
<figcaption>One of the convolutional neural network architectures I've tried out for latent factor prediction. The time axis (which is convolved over) is vertical.</figcaption>
</figure>
<p>The input to the network consists of <strong>mel-spectrograms</strong>, with 599 frames and 128 frequency bins. A mel-spectrogram is a kind of <strong>time-frequency representation</strong>. It is obtained from an audio signal by computing the Fourier transforms of short, overlapping windows. Each of these Fourier transforms constitutes a <em>frame</em>. These successive frames are then concatenated into a matrix to form the spectrogram. Finally, the frequency axis is changed from a linear scale to a <a href="http://en.wikipedia.org/wiki/Mel_scale">mel scale</a> to reduce the dimensionality, and the magnitudes are scaled logarithmically.</p>
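For illustration, this pipeline can be sketched in plain NumPy. The window size, hop size and sample rate below are placeholder values, not necessarily the ones used to produce the 599×128 inputs.

```python
import numpy as np

def mel_spectrogram(signal, sr=22050, n_fft=1024, hop=512, n_mels=128):
    # Frame the signal into short, overlapping windows (Hann-windowed).
    window = np.hanning(n_fft)
    n_frames = 1 + (len(signal) - n_fft) // hop
    frames = np.stack([signal[i * hop:i * hop + n_fft] * window
                       for i in range(n_frames)])
    # Each frame's Fourier transform gives one column of the spectrogram.
    spec = np.abs(np.fft.rfft(frames, axis=1)) ** 2  # (n_frames, n_fft//2+1)
    # Triangular mel filterbank: warps the linear frequency axis to a
    # mel scale and reduces it to n_mels bins.
    def hz_to_mel(f): return 2595.0 * np.log10(1.0 + f / 700.0)
    def mel_to_hz(m): return 700.0 * (10.0 ** (m / 2595.0) - 1.0)
    hz_points = mel_to_hz(np.linspace(hz_to_mel(0.0), hz_to_mel(sr / 2.0),
                                      n_mels + 2))
    bin_freqs = np.fft.rfftfreq(n_fft, 1.0 / sr)
    fbank = np.zeros((n_mels, len(bin_freqs)))
    for m in range(1, n_mels + 1):
        left, center, right = hz_points[m - 1], hz_points[m], hz_points[m + 1]
        up = (bin_freqs - left) / (center - left)
        down = (right - bin_freqs) / (right - center)
        fbank[m - 1] = np.maximum(0.0, np.minimum(up, down))
    # Scale the magnitudes logarithmically.
    return np.log1p(spec @ fbank.T)  # (n_frames, n_mels)
```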
<p>The <strong>convolutional layers</strong> are displayed as red rectangles delineating the shape of the filters that slide across their inputs. They have rectified linear units (ReLUs, with activation function <code class="language-plaintext highlighter-rouge">max(0, x)</code>). Note that all these convolutions are <strong>one-dimensional</strong>; the convolution happens only in the time dimension, not in the frequency dimension. Although it is technically possible to convolve along both axes of the spectrogram, I am not currently doing this. It is important to realize that the two axes of a spectrogram have different meanings (time vs. frequency), which is not the case for images. As a result, it doesn’t really make sense to use square filters, which is what is typically done in convnets for image data.</p>
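A minimal sketch of such a time-only convolution: each filter spans the full frequency axis, so sliding it leaves no frequency dimension in the output. The shapes and the placement of the ReLU are illustrative; the real layers were implemented in Theano.

```python
import numpy as np

def conv1d_time(spectrogram, filters):
    """Convolve filters over the time axis only.

    spectrogram: (n_frames, n_bins)
    filters:     (n_filters, filter_width, n_bins) -- each filter covers
                 all frequency bins at once.
    Returns:     (n_frames - filter_width + 1, n_filters)
    """
    n_frames, n_bins = spectrogram.shape
    n_filters, width, fbins = filters.shape
    assert fbins == n_bins
    out = np.empty((n_frames - width + 1, n_filters))
    for t in range(out.shape[0]):
        window = spectrogram[t:t + width]  # (width, n_bins)
        # Sum over both the filter-width and frequency axes.
        out[t] = np.tensordot(filters, window, axes=([1, 2], [0, 1]))
    return np.maximum(out, 0.0)  # ReLU: max(0, x)
```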
<p>Between the convolutional layers, there are <strong>max-pooling operations</strong> to downsample the intermediate representations in time, and to add some time invariance in the process. These are indicated with ‘<strong>MP</strong>’. As you can see, I used a filter size of 4 frames in every convolutional layer, with max-pooling with a pool size of 4 between the first and second convolutional layers (mainly for performance reasons), and with a pool size of 2 between the other layers.</p>
<p>After the last convolutional layer, I added a <strong>global temporal pooling layer</strong>. This layer pools across the entire time axis, effectively computing statistics of the learned features across time. I included three different pooling functions: the mean, the maximum and the L2-norm.</p>
<p>I did this because the absolute location of features detected in the audio signal is not particularly relevant for the task at hand. This is not the case in image classification: in an image, it can be useful to know roughly where a particular feature was detected. For example, a feature detecting clouds would be more likely to activate for the top half of an image. If it activates in the bottom half, maybe it is actually detecting a sheep instead. For music recommendation, we are typically only interested in the overall presence or absence of certain features in the music, so it makes sense to perform pooling across time.</p>
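The global temporal pooling layer can be sketched as follows; this is a simplified NumPy stand-in for the actual Theano implementation.

```python
import numpy as np

def global_temporal_pool(features):
    """Pool a (time, n_features) activation matrix across the whole time
    axis, concatenating the mean, the maximum and the L2-norm of each
    feature -- three statistics, so the output is 3 * n_features long."""
    return np.concatenate([
        features.mean(axis=0),
        features.max(axis=0),
        np.linalg.norm(features, axis=0),
    ])
```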
<p>Another way to approach this problem would be to train the network on short audio fragments, and average the outputs across windows for longer fragments, as we did in the NIPS paper. However, incorporating the pooling into the model seems like a better idea, because it allows for this step to be taken into account during learning.</p>
<p>The globally pooled features are fed into a series of <strong>fully-connected layers</strong> with 2048 rectified linear units. In this network, I have two of them. The last layer of the network is the <strong>output layer</strong>, which predicts 40 latent factors obtained from the <a href="http://erikbern.com/?p=396"><strong>vector_exp</strong> algorithm</a>, one of the various collaborative filtering algorithms that are used at Spotify.</p>
<h3 id="training">Training</h3>
<p>The network is trained to minimize the <strong>mean squared error</strong> (MSE) between the latent factor vectors from the collaborative filtering model and the predictions from audio. These vectors are first normalized so they have a unit norm. This is done to reduce the influence of song popularity (the norms of latent factor vectors tend to be correlated with song popularity for many collaborative filtering models). Dropout is used in the dense layers for regularization.</p>
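A sketch of the target normalization and the objective (function names are illustrative):

```python
import numpy as np

def normalize(factors):
    """Rescale each latent factor vector to unit norm. The norm tends to
    correlate with song popularity, which we don't want the network to
    spend capacity predicting."""
    return factors / np.linalg.norm(factors, axis=1, keepdims=True)

def mse(predictions, targets):
    """Mean squared error between predicted and true factor vectors."""
    return np.mean((predictions - targets) ** 2)
```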
<p>The dataset I am currently using consists of mel-spectrograms of 30 second excerpts extracted from the middle of the 1 million most popular tracks on Spotify. I am using about half of these for training (0.5M), about 5000 for online validation, and the remainder for testing. During training, the data is augmented by slightly cropping the spectrograms along the time axis with a random offset.</p>
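The cropping augmentation can be sketched as below; the function name and signature are made up for illustration, and only the idea of a random offset along the time axis comes from the post.

```python
import numpy as np

def random_time_crop(spectrogram, crop_frames, rng):
    """Take `crop_frames` consecutive frames at a random offset along the
    time axis; the frequency axis is left untouched."""
    max_offset = spectrogram.shape[0] - crop_frames
    offset = rng.integers(0, max_offset + 1)
    return spectrogram[offset:offset + crop_frames]
```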
<p>The network is implemented in <strong><a href="http://www.deeplearning.net/software/theano/">Theano</a></strong>, and trained using minibatch gradient descent with Nesterov momentum on an NVIDIA GeForce GTX 780Ti GPU. Data loading and augmentation happens in a separate process, so while the GPU is training on a chunk of data, the next one can be loaded in parallel. About 750,000 gradient updates are performed in total. I don’t remember exactly how long this particular architecture took to train, but all of the ones I’ve tried have taken between 18 and 36 hours.</p>
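The Nesterov momentum update itself is simple to sketch; the learning rate and momentum values here are placeholders, not the ones used in training.

```python
import numpy as np

def nesterov_step(params, velocity, grad_fn, lr=0.1, momentum=0.9):
    """One gradient update with Nesterov momentum: the gradient is
    evaluated at the look-ahead point rather than at the current
    parameters, which tends to damp oscillations."""
    lookahead = params + momentum * velocity
    velocity = momentum * velocity - lr * grad_fn(lookahead)
    return params + velocity, velocity
```

As a quick sanity check, repeatedly applying this step to a simple quadratic drives the parameter toward its minimum.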
<h3 id="variations">Variations</h3>
<p>As I mentioned before, this is just one example of an architecture that I’ve tried. Some other things I have tried / will try include:</p>
<ul>
<li><a href="/images/moar.jpg">More layers!</a></li>
<li>Using maxout units instead of rectified linear units.</li>
<li>Using stochastic pooling instead of max-pooling.</li>
<li>Incorporating L2 normalization into the output layer of the network.</li>
<li>Data augmentation by stretching or compressing the spectrograms across time.</li>
<li>Concatenating multiple latent factor vectors obtained from different collaborative filtering models.</li>
</ul>
<p>Here are some things that didn’t work quite as well as I’d hoped:</p>
<ul>
<li>Adding ‘bypass’ connections from all convolutional layers to the fully connected part of the network, with global temporal pooling in between. The underlying assumption was that statistics about low-level features could also be useful for recommendation, but unfortunately this hampered learning too much.</li>
<li>Predicting the conditional variance of the factors as in <a href="http://eprints.aston.ac.uk/373/">mixture density networks</a>, to get confidence estimates for the predictions and to identify songs for which latent factor prediction is difficult. Unfortunately this seemed to make training quite a lot harder, and the resulting confidence estimates did not behave as expected.</li>
</ul>
<h2 id="analysis-what-is-it-learning"><a name="analysis"></a>Analysis: what is it learning?</h2>
<p>Now for the cool part: <strong>what are these networks learning? What do the features look like?</strong> The main reason I chose to tackle this problem with convnets, is because I believe that music recommendation from audio signals is a pretty complex problem bridging many levels of abstraction. My hope was that successive layers of the network would learn progressively more complex and invariant features, as they do for image classification problems.</p>
<p>It looks like that’s exactly what is happening. First, let’s take a look at the first convolutional layer, which learns a set of filters that are applied directly to the input spectrograms. These filters are easy to visualize. They are shown in the image below. Click for a high resolution version (5584x562, ~600kB). Negative values are red, positive values are blue and white is zero. Note that each filter is only four frames wide. The individual filters are separated by dark red vertical lines.</p>
<figure>
<a href="/images/filters_hires.png"><img src="/images/filters_lores.png" alt="Filters learned in the first convolutional layer." /></a>
<figcaption>Visualization of the filters learned in the first convolutional layer. The time axis is horizontal, the frequency axis is vertical (frequency increases from top to bottom). Click for a high resolution version (5584x562, ~600kB).</figcaption>
</figure>
<p>From this representation, we can see that a lot of the filters pick up harmonic content, which manifests itself as parallel red and blue bands at different frequencies. Sometimes, these bands are slanted up or down, indicating the presence of rising and falling pitches. It turns out that these filters tend to detect human voices.</p>
<h3 id="playlists-for-low-level-features-maximal-activation">Playlists for low-level features: maximal activation</h3>
<p><strong>To get a better idea of what the filters learn, I made some playlists with songs from the test set that maximally activate them.</strong> Below are a few examples. There are 256 filters in the first layer of the network, which I numbered from 0 to 255. Note that this numbering is arbitrary, as they are unordered.</p>
<p>These four playlists were obtained by finding songs that maximally activate a given filter in the 30 seconds that were analyzed. I selected a few interesting looking filters from the first convolutional layer and computed the feature representations for each of these, and then searched for the maximal activations across the entire test set. <strong>Note that you should listen to the middle of the tracks to hear what the filters are picking up on, as this is the part of the audio signal that was analyzed.</strong></p>
<p>All of the Spotify playlists below should have 10 tracks. Some of them may not be available in all countries due to licensing issues.</p>
<div style="float: left; margin: .5em;">
<figcaption style="text-align: center;">Filter 14: vibrato singing</figcaption>
<iframe src="https://embed.spotify.com/?uri=spotify:user:sander_dieleman:playlist:3KcOA1o1Q1E7AKJp6lRKOG" width="300" height="380" frameborder="0" allowtransparency="true"></iframe>
</div>
<div style="float: left; margin: .5em;">
<figcaption style="text-align: center;">Filter 242: ambience</figcaption>
<iframe src="https://embed.spotify.com/?uri=spotify:user:sander_dieleman:playlist:4XSo1qV9vGNwM7uz8LYwDR" width="300" height="380" frameborder="0" allowtransparency="true"></iframe>
</div>
<div style="float: left; margin: .5em;">
<figcaption style="text-align: center;">Filter 250: vocal thirds</figcaption>
<iframe src="https://embed.spotify.com/?uri=spotify:user:sander_dieleman:playlist:2CRSJ4h9cWvDSwNoqh9UJC" width="300" height="380" frameborder="0" allowtransparency="true"></iframe>
</div>
<div style="float: left; margin: .5em;">
<figcaption style="text-align: center;">Filter 253: bass drums</figcaption>
<iframe src="https://embed.spotify.com/?uri=spotify:user:sander_dieleman:playlist:121nUXz11tA96ONF4Dk2Eh" width="300" height="380" frameborder="0" allowtransparency="true"></iframe>
</div>
<div style="clear: both;"></div>
<figure>
<a href="/images/feature_closeup_max.png"><img style="display: block; margin-left: auto; margin-right: auto; height: 300px;" src="/images/feature_closeup_max.png" alt="Closeup of filters 14, 242, 250 and 253." /></a>
<figcaption style="text-align: center;">Closeup of filters 14, 242, 250 and 253. Click for a larger version.</figcaption>
</figure>
<ul>
<li>Filter 14 seems to pick up <strong>vibrato singing</strong>.</li>
<li>Filter 242 picks up some kind of <strong>ringing ambience</strong>.</li>
<li>Filter 250 picks up <strong>vocal thirds</strong>, i.e. multiple singers singing the same thing, but the notes are a major third (4 semitones) apart.</li>
<li>Filter 253 picks up various types of <strong>bass drum sounds</strong>.</li>
</ul>
<p>The genres of the tracks in these playlists are quite varied, which indicates that these features are picking up mainly low-level properties of the audio signals.</p>
<h3 id="playlists-for-low-level-features-average-activation">Playlists for low-level features: average activation</h3>
<p>The next four playlists were obtained in a slightly different way: I computed the <strong>average activation of each feature across time</strong> for each track, and then found the maximum across those. This means that for these playlists, the filter in question is constantly active in the 30 seconds that were analyzed (i.e. it’s not just one ‘peak’). This is more useful for detecting harmonic patterns.</p>
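Both ranking schemes, maximal and average activation, can be sketched as follows, assuming the first-layer feature maps have been precomputed for every track (shapes and names are illustrative):

```python
import numpy as np

def top_tracks(activations, filter_idx, k=10, mode="max"):
    """Rank tracks by a single filter's activation.

    activations: (n_tracks, n_frames, n_filters) feature maps.
    mode="max"  ranks by the peak activation anywhere in the excerpt;
    mode="mean" ranks by the average activation across time, which
    favours tracks where the filter is constantly active.
    """
    per_track = activations[:, :, filter_idx]   # (n_tracks, n_frames)
    if mode == "max":
        scores = per_track.max(axis=1)
    else:
        scores = per_track.mean(axis=1)
    return np.argsort(scores)[::-1][:k]
```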
<div style="float: left; margin: .5em;">
<figcaption style="text-align: center;">Filter 1: noise, distortion</figcaption>
<iframe src="https://embed.spotify.com/?uri=spotify:user:sander_dieleman:playlist:0vRgR1RreMyOyZAnfF5Qzr" width="300" height="380" frameborder="0" allowtransparency="true"></iframe>
</div>
<div style="float: left; margin: .5em;">
<figcaption style="text-align: center;">Filter 2: pitch (A, Bb)</figcaption>
<iframe src="https://embed.spotify.com/?uri=spotify:user:sander_dieleman:playlist:38jOUhgqrT9Fw4xcFxnTCf" width="300" height="380" frameborder="0" allowtransparency="true"></iframe>
</div>
<div style="float: left; margin: .5em;">
<figcaption style="text-align: center;">Filter 4: drones</figcaption>
<iframe src="https://embed.spotify.com/?uri=spotify:user:sander_dieleman:playlist:2wYoHEQ082aw3sI4CWqwsi" width="300" height="380" frameborder="0" allowtransparency="true"></iframe>
</div>
<div style="float: left; margin: .5em;">
<figcaption style="text-align: center;">Filter 28: chord (A, Am)</figcaption>
<iframe src="https://embed.spotify.com/?uri=spotify:user:sander_dieleman:playlist:2s1TOuZb5lGoa30BFucE5K" width="300" height="380" frameborder="0" allowtransparency="true"></iframe>
</div>
<div style="clear: both;"></div>
<figure>
<a href="/images/feature_closeup_mean.png"><img style="display: block; margin-left: auto; margin-right: auto; height: 300px;" src="/images/feature_closeup_mean.png" alt="Closeup of filters 1, 2, 4 and 28." /></a>
<figcaption style="text-align: center;">Closeup of filters 1, 2, 4 and 28. Click for a larger version.</figcaption>
</figure>
<ul>
<li>Filter 1 picks up <strong>noise</strong> and (guitar) <strong>distortion</strong>.</li>
<li>Filter 2 seems to pick up a <strong>specific pitch: a low Bb</strong>. It also picks up A sometimes (a semitone below) because the frequency resolution of the mel-spectrograms is not high enough to distinguish them.</li>
<li>Filter 4 picks up various low-pitched <strong>drones</strong>.</li>
<li>Filter 28 picks up the <strong>A chord</strong>. It seems to pick up both the minor and major versions, so it might just be detecting the pitches A and E (the fifth).</li>
</ul>
<p>I thought it was very interesting that the network is learning to detect specific pitches and chords. I had previously assumed that the exact pitches or chords occurring in a song would not really affect listener preference. I have two hypotheses for why this might be happening:</p>
<ul>
<li>The network is just learning to detect <strong>harmonicity</strong>, by learning various filters for different kinds of harmonics. These are then pooled together at a higher level to detect harmonicity across different pitches.</li>
<li>The network is learning that some <strong>chords and chord progressions</strong> are more common than others in certain genres of music.</li>
</ul>
<p>I have not tried to verify either of these, but it seems like the latter would be pretty challenging for the network to pick up on, so I think the former is more likely.</p>
<h3 id="playlists-for-high-level-features">Playlists for high-level features</h3>
<p>Each layer in the network takes the feature representation from the layer below, and extracts a set of higher-level features from it. <strong>At the topmost fully-connected layer of the network, just before the output layer, the learned filters turn out to be very selective for certain subgenres</strong>. For obvious reasons, it is non-trivial to visualize what these filters pick up on at the spectrogram level. Below are six playlists with songs from the test set that maximally activate some of these high-level filters.</p>
<div style="float: left; margin: .5em;">
<figcaption style="text-align: center;">Filter 3: christian rock</figcaption>
<iframe src="https://embed.spotify.com/?uri=spotify:user:sander_dieleman:playlist:48cTKQ3ubm8osDVyrML9p9" width="300" height="380" frameborder="0" allowtransparency="true"></iframe>
</div>
<div style="float: left; margin: .5em;">
<figcaption style="text-align: center;">Filter 15: choirs / a cappella + smooth jazz</figcaption>
<iframe src="https://embed.spotify.com/?uri=spotify:user:sander_dieleman:playlist:7tQShxnjMycjb0NsQJuU8f" width="300" height="380" frameborder="0" allowtransparency="true"></iframe>
</div>
<div style="float: left; margin: .5em;">
<figcaption style="text-align: center;">Filter 26: gospel</figcaption>
<iframe src="https://embed.spotify.com/?uri=spotify:user:sander_dieleman:playlist:4ZRJs9g1GzqVRE6wh1YNdT" width="300" height="380" frameborder="0" allowtransparency="true"></iframe>
</div>
<div style="float: left; margin: .5em;">
<figcaption style="text-align: center;">Filter 37: Chinese pop</figcaption>
<iframe src="https://embed.spotify.com/?uri=spotify:user:sander_dieleman:playlist:25GNJ9ygtFxt1j3jWwdoYv" width="300" height="380" frameborder="0" allowtransparency="true"></iframe>
</div>
<div style="float: left; margin: .5em;">
<figcaption style="text-align: center;">Filter 49: chiptune, 8-bit</figcaption>
<iframe src="https://embed.spotify.com/?uri=spotify:user:sander_dieleman:playlist:0jSEMRLNNNVvOkPHuDUAYM" width="300" height="380" frameborder="0" allowtransparency="true"></iframe>
</div>
<div style="float: left; margin: .5em;">
<figcaption style="text-align: center;">Filter 1024: deep house</figcaption>
<iframe src="https://embed.spotify.com/?uri=spotify:user:sander_dieleman:playlist:3cS6tiQOpW0eXsyAlxcrqO" width="300" height="380" frameborder="0" allowtransparency="true"></iframe>
</div>
<div style="clear: both;"></div>
<p>It is clear that each of these filters is identifying specific genres. Interestingly some filters, like #15 for example, seem to be <strong>multimodal</strong>: they activate strongly for two or more styles of music, and those styles are often completely unrelated. Presumably the output of these filters is disambiguated when viewed in combination with the activations of all other filters.</p>
<p>Filter 37 is interesting because it almost seems like it is <strong>identifying the Chinese language</strong>. This is not entirely impossible, since the phoneme inventory of Chinese is quite distinct from other languages. A few other filters also seem to be language-specific: there is one that detects rap music in Spanish, for example. Another possibility is that Chinese pop music has some other characteristic that sets it apart, and the model is picking up on that instead.</p>
<p>I spent some time analyzing the first 50 or so filters in detail. Some other filter descriptions I came up with are: lounge, reggae, darkwave, country, metalcore, salsa, Dutch and German carnival music, children’s songs, dance, vocal trance, punk, Turkish pop, and my favorite, ‘exclusively Armin van Buuren’. Apparently he has so many tracks that he gets his own filter.</p>
<p>The filters learned by <a href="http://papers.nips.cc/paper/4824-imagenet-classification-with-deep-convolutional-neural-networks">Alex Krizhevsky’s ImageNet network</a> have been reused for various other computer vision tasks with great success. Based on their diversity and invariance properties, it seems that these filters learned from audio signals may also be useful for other music information retrieval tasks besides predicting latent factors.</p>
<h3 id="similarity-based-playlists">Similarity-based playlists</h3>
<p>Predicted latent factor vectors can be used to find songs that sound similar. Below are a couple of playlists that were generated by predicting the factor vector for a given song, and then <strong>finding other songs in the test set whose predicted factor vectors are close to it</strong> in terms of cosine distance. As a result, the first track in the playlist is always the query track itself.</p>
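A sketch of how such a playlist can be assembled from predicted factor vectors (illustrative names; the query track comes out first because its cosine distance to itself is zero):

```python
import numpy as np

def similarity_playlist(predicted_factors, query_idx, length=10):
    """Rank all songs by cosine distance between their predicted factor
    vectors and the query's, and return the closest `length` of them."""
    unit = predicted_factors / np.linalg.norm(predicted_factors, axis=1,
                                              keepdims=True)
    cosine_distance = 1.0 - unit @ unit[query_idx]
    return np.argsort(cosine_distance)[:length]
```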
<div style="float: left; margin: .5em;">
<figcaption style="text-align: center; width: 290px; height: 50px;">The Notorious B.I.G. - Juicy (hip hop)</figcaption>
<iframe src="https://embed.spotify.com/?uri=spotify:user:sander_dieleman:playlist:6wnPsncVsmApMRj5g7PWkz" width="300" height="380" frameborder="0" allowtransparency="true"></iframe>
</div>
<div style="float: left; margin: .5em;">
<figcaption style="text-align: center; width: 290px; height: 50px;">Cloudkicker - He would be riding on the subway... (post-rock, post-metal)</figcaption>
<iframe src="https://embed.spotify.com/?uri=spotify:user:sander_dieleman:playlist:6K9Df3nXsZVftKmYliUcIS" width="300" height="380" frameborder="0" allowtransparency="true"></iframe>
</div>
<div style="float: left; margin: .5em;">
<figcaption style="text-align: center; width: 290px; height: 50px;">Architects - Numbers Count For Nothing (metalcore, hardcore)</figcaption>
<iframe src="https://embed.spotify.com/?uri=spotify:user:sander_dieleman:playlist:5bAhiBYDlsCkr2BzKKIS1b" width="300" height="380" frameborder="0" allowtransparency="true"></iframe>
</div>
<div style="float: left; margin: .5em;">
<figcaption style="text-align: center; width: 290px; height: 50px;">Neophyte - Army of Hardcore (hardcore techno, gabber)</figcaption>
<iframe src="https://embed.spotify.com/?uri=spotify:user:sander_dieleman:playlist:3lwuKlfrqnqxaYKF10KOT4" width="300" height="380" frameborder="0" allowtransparency="true"></iframe>
</div>
<div style="clear: both;"></div>
<div style="float: left; margin: .5em;">
<figcaption style="text-align: center; width: 290px; height: 50px;">Fleet Foxes - Sun It Rises (indie folk)</figcaption>
<iframe src="https://embed.spotify.com/?uri=spotify:user:sander_dieleman:playlist:0aFDjurTiPd3VO8RO8T6uz" width="300" height="380" frameborder="0" allowtransparency="true"></iframe>
</div>
<div style="float: left; margin: .5em;">
<figcaption style="text-align: center; width: 290px; height: 50px;">John Coltrane - My Favorite Things (jazz)</figcaption>
<iframe src="https://embed.spotify.com/?uri=spotify:user:sander_dieleman:playlist:66rHmJlejT0PQh74JXV8Ie" width="300" height="380" frameborder="0" allowtransparency="true"></iframe>
</div>
<div style="clear: both;"></div>
<p>Most of the similar tracks are pretty decent recommendations for fans of the query tracks. Of course these lists are far from perfect, but considering that they were obtained based only on the audio signals, the results are pretty decent. One example where things go wrong is the list for ‘My Favorite Things’ by John Coltrane, which features a couple of strange outliers, most notably ‘Crawfish’ by Elvis Presley. This is probably because the part of the audio signal that was analyzed (8:40 to 9:10) contains a crazy sax solo. Analyzing the whole song might give better results.</p>
<h2 id="what-will-this-be-used-for"><a name="whatwill"></a>What will this be used for?</h2>
<p>Spotify already uses a bunch of different information sources and algorithms in their recommendation pipeline, so the most obvious application of my work is simply to include it as an extra signal. However, it could also be used to filter outliers from recommendations made by other algorithms. As I mentioned earlier, collaborative filtering algorithms will tend to include intro tracks, outro tracks, cover songs and remixes in their recommendations. These could be filtered out effectively using an audio-based approach.</p>
<p>One of my main goals with this work is to make it possible to recommend new and unpopular music. I hope that this will help lesser-known and up-and-coming bands, and that it will level the playing field somewhat by enabling Spotify to recommend their music to the right audience. (Promoting up-and-coming bands also happens to be one of the main objectives of my non-profit website <a href="http://got-djent.com/">got-djent.com</a>.)</p>
<p>Hopefully some of this will be ready for <a href="http://en.wikipedia.org/wiki/A/B_testing">A/B testing</a> soon, so we can find out if these audio-based recommendations actually make a difference in practice. This is something I’m very excited about, as it’s not something you can easily do in academia.</p>
<h2 id="future-work"><a name="futurework"></a>Future work</h2>
<p>Another type of user feedback that Spotify collects are the <strong>thumbs up</strong> and <strong>thumbs down</strong> that users can give to tracks played on radio stations. This type of information is very useful to determine which tracks are similar. Unfortunately, it is also quite noisy. I am currently trying to use this data in a ‘<a href="http://en.wikipedia.org/wiki/Learning_to_rank">learning to rank</a>’ setting. I’ve also been experimenting with various distance metric learning schemes, such as <a href="https://plus.google.com/+YannLeCunPhD/posts/8biVDbVrhAp">DrLIM</a>. If anything cool comes out of that I might write another post about it.</p>
<h2 id="conclusion"><a name="conclusion"></a>Conclusion</h2>
<p>In this post I’ve given an overview of my work so far as a machine learning intern at Spotify. I’ve explained my approach to using convnets for audio-based music recommendation and I’ve tried to provide some insight into what the networks actually learn. For more details about the approach, please refer to the NIPS 2013 paper ‘<a href="https://papers.nips.cc/paper/5004-deep-content-based-music-recommendation">Deep content-based music recommendation</a>’ by Aäron van den Oord and myself.</p>
<p>If you are interested in deep learning, feature learning and its applications to music, have a look at my <a href="http://benanne.github.io/research/">research page</a> for an overview of some other work I have done in this domain. If you’re interested in Spotify’s approach to music recommendation, check out <a href="http://www.slideshare.net/MrChrisJohnson/algorithmic-music-recommendations-at-spotify">these</a> <a href="http://www.slideshare.net/erikbern/music-recommendations-mlconf-2014">presentations</a> on Slideshare and <a href="http://erikbern.com/">Erik Bernhardsson’s blog</a>.</p>
<p>Spotify is a really cool place to work at. They are very open about their methods (and they let me write this blog post), which is not something you come across often in industry. If you are interested in recommender systems, collaborative filtering and/or music information retrieval, and you’re looking for an internship or <a href="https://www.spotify.com/us/jobs/opportunities/">something more permanent</a>, don’t hesitate to get in touch with them.</p>
<p>If you have any questions or feedback about this post, feel free to leave a comment!</p>
<ul>
<li>Post on <a href="https://news.ycombinator.com/item?id=8137264">Hacker News</a></li>
<li>Post on <a href="http://www.reddit.com/r/MachineLearning/comments/2cozew/recommending_music_on_spotify_with_deep_learning/">r/machinelearning</a></li>
<li>Post on the <a href="https://plus.google.com/u/1/+SanderDieleman/posts/Hb3R4YfaANW">Google+ deep learning community</a></li>
<li>Post on the <a href="https://plus.google.com/u/1/+SanderDieleman/posts/H6n9TvMdyKT">Google+ music information retrieval community</a></li>
</ul>
<figure>
<img style="display: block; margin-left: auto; margin-right: auto; width: 600px;" src="/images/deck.jpg" alt="View of NYC from the Spotify deck." />
<figcaption style="text-align: center;">View of NYC from the <a href="http://www.businessinsider.com/take-a-tour-of-spotify-new-nyc-office-2014-7?op=1">Spotify deck</a>.</figcaption>
</figure>
<h2><a href="https://benanne.github.io/2014/05/29/slides-meetup">Slides for my talk at the Deep Learning London meetup</a> (2014-05-29)</h2>
<p>Yesterday, I gave a talk at the <a href="http://www.meetup.com/Deep-Learning-London/events/183804302/">Deep Learning London Meetup</a> about my <a href="http://benanne.github.io/research/">PhD research</a> and <a href="http://benanne.github.io/2014/04/05/galaxy-zoo.html">my approach</a> to winning the <a href="http://www.kaggle.com/c/galaxy-zoo-the-galaxy-challenge">Galaxy Zoo challenge on Kaggle</a>. The slides for my talk are available for download:</p>
<ul>
<li><a href="https://dl.dropboxusercontent.com/u/19706734/music_galaxies.pdf">Download the slides in PDF format <strong>(16MB)</strong></a></li>
<li><a href="https://dl.dropboxusercontent.com/u/19706734/music_galaxies.pptx">Download the slides in PPTX format <strong>(82MB)</strong></a></li>
</ul>
<p>The three papers I discussed in the first part of the talk are described <a href="http://benanne.github.io/research/">here</a>; download links to the PDFs are included. A detailed description of my solution for the Galaxy Challenge is available in <a href="http://benanne.github.io/2014/04/05/galaxy-zoo.html">an earlier post on this blog</a>. The code for all 17 models included in the winning ensemble is available on <a href="https://github.com/benanne/kaggle-galaxies">GitHub</a>.</p>
<h2><a href="https://benanne.github.io/2014/05/12/fft-convolutions-in-theano">Even faster convolutions in Theano using FFTs</a> (2014-05-12)</h2>
<p><a href="http://benanne.github.io/2014/04/03/faster-convolutions-in-theano.html">Last month</a> I wrote about how you can use the cuda-convnet wrappers in pylearn2 to get up to 3x faster GPU convolutions in Theano. Since then I’ve been working on an FFT-based convolution implementation for Theano. Preliminary tests indicate that this approach is again 2-4x faster than the cuda-convnet wrappers.</p>
<p>I wrote the code in pure Python, using <a href="https://github.com/lebedov/scikits.cuda">scikits.cuda</a> and <a href="http://mathema.tician.de/software/pycuda/">PyCUDA</a> to do the heavy lifting. The Theano team is <a href="https://groups.google.com/forum/#!topic/theano-users/6xiFFpBBDq0">currently working on integrating this code into Theano</a>. They also plan to create a proper C/CUDA implementation to guarantee the best performance.</p>
<p>I put everything up on GitHub, you can find the code there, or clone it and try it yourself:</p>
<ul>
<li><strong><a href="https://github.com/benanne/theano_fftconv">https://github.com/benanne/theano_fftconv</a></strong></li>
</ul>
<h2 id="fft-based-convolution">FFT-based convolution</h2>
<p>The Fourier transform of a convolution of two functions is the product of the Fourier transforms of those functions. This is the <a href="http://en.wikipedia.org/wiki/Convolution_theorem">convolution theorem</a>. This result can be used to quickly compute convolutions in the Fourier domain, since an elementwise product is much less computationally intensive than a convolution.</p>
<p>However, there is a price to be paid: the inputs need to be transformed using the Fast Fourier Transform (FFT), and the product of these transformed inputs needs to be transformed again using the inverse FFT. Depending on the sizes of the inputs, these costs can be pretty significant, so sometimes it is a better idea to just compute the convolution in the original domain.</p>
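<p>The convolution theorem is easy to verify numerically. Here is a small NumPy sketch (the signals and sizes are made up for illustration) that computes a full 1D convolution both directly and via FFTs:</p>

```python
import numpy as np

x = np.random.randn(32)  # input signal
w = np.random.randn(5)   # filter

# Direct (full) convolution.
direct = np.convolve(x, w, mode="full")

# Same result via the convolution theorem: zero-pad both signals
# to the full output length, multiply their FFTs, then invert.
n = len(x) + len(w) - 1
fft_based = np.fft.irfft(np.fft.rfft(x, n) * np.fft.rfft(w, n), n)

assert np.allclose(direct, fft_based)
```

<p>The zero-padding is what turns the circular convolution implied by the FFT into the linear convolution we actually want.</p>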
<p>I was somewhat surprised to learn that all popular implementations of convolutional neural networks (CNNs) use the latter approach, including that of Theano and cuda-convnet. The reason is that typically, convolutions in CNNs involve relatively small filters, so I think people just assumed it wasn’t worth it.</p>
<p>However, a paper published at ICLR 2014 recently caught my eye: <a href="http://openreview.net/document/aa6ab717-ca19-47e1-a958-823b9a106ca9#aa6ab717-ca19-47e1-a958-823b9a106ca9">Fast Training of Convolutional Networks through FFTs</a> by Mathieu, Henaff and LeCun. They implemented the FFT-based approach in the <a href="http://torch.ch/">Torch7 framework</a> and compared its performance to Torch7’s own ‘classical’ implementation. They concluded that it is actually advantageous to use FFT-based convolutions in CNNs in many cases.</p>
<p>The reason is actually quite straightforward: compared to the general case, the overhead of computing the FFTs of the inputs is drastically reduced. We need to compute the convolution of each input example in a given minibatch with each filter. If there are <code class="language-plaintext highlighter-rouge">m</code> examples in the minibatch with <code class="language-plaintext highlighter-rouge">k</code> input channels, and <code class="language-plaintext highlighter-rouge">n</code> filters, this means we need to compute <code class="language-plaintext highlighter-rouge">m * n * k</code> convolutions. In the Fourier domain, this turns into <code class="language-plaintext highlighter-rouge">m * n * k</code> elementwise products. However, <strong>we only need to compute the FFT of each input example and each filter once</strong>. So the total number of FFTs to compute is not <code class="language-plaintext highlighter-rouge">2 * m * n * k</code>, but <code class="language-plaintext highlighter-rouge">(m + n) * k</code>.</p>
<p>But that’s not everything: the output of a convolutional layer in a CNN is actually a sum of convolutions across all <code class="language-plaintext highlighter-rouge">k</code> input channels. Because the FFT is a linear operator, we can compute this sum in the Fourier domain, and then take the IFFT of this sum (instead of the other way around). This means we only need to compute <code class="language-plaintext highlighter-rouge">m * n</code> IFFTs, instead of <code class="language-plaintext highlighter-rouge">m * n * k</code>. It turns out that these savings can be very significant.</p>
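<p>To get a feel for these savings, here is a back-of-the-envelope count for one hypothetical layer (the values of <code class="language-plaintext highlighter-rouge">m</code>, <code class="language-plaintext highlighter-rouge">n</code> and <code class="language-plaintext highlighter-rouge">k</code> are made up for illustration):</p>

```python
m, n, k = 64, 32, 16  # minibatch size, number of filters, input channels

# Naive: one forward and one inverse FFT per (example, filter, channel) pair.
naive_ffts = 2 * m * n * k

# Reuse transforms: each example and each filter is transformed once per
# channel, and the channel sum is taken in the Fourier domain, so only
# m * n inverse FFTs are needed.
reused_ffts = (m + n) * k + m * n

print(naive_ffts, reused_ffts)  # 65536 vs. 3584 transforms
```

<p>For this layer the number of transforms drops by more than an order of magnitude, which is why the FFT overhead is much less of a problem in CNNs than in the general case.</p>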
<h2 id="a-cudac-less-theano-implementation">A CUDA/C-less Theano implementation</h2>
<p>So this got me thinking that it should be possible to do the same thing in Theano. Theano already intelligently replaces convolution operators in computational graphs with their GPU-based counterparts in the optimization phase. If an FFT-based implementation was added, it could do the same with that version instead.</p>
<p>I set out to implement this, but unfortunately my knowledge of CUDA is nonexistent, and my knowledge of C can be called rusty at best. So I sought to avoid both. Enter <a href="https://github.com/lebedov/scikits.cuda">scikits.cuda</a>, which offers all the necessary primitives: forward and inverse FFTs, and complex products (the FFT of a real signal is complex and symmetric).</p>
<p>Luckily, scikits.cuda is built on top of <a href="http://mathema.tician.de/software/pycuda/">PyCUDA</a>, and the Theano docs have some examples of how to implement PyCUDA-based operators. Essentially I just had to glue everything together.</p>
<h2 id="implementation-details">Implementation details</h2>
<p>As mentioned earlier, an FFT-based convolution can be broken up into 3 parts: an FFT of the input images and the filters, a bunch of elementwise products followed by a sum across input channels, and then an IFFT of the outputs. I decided to implement each of these as a separate Theano operator. That way, the optimizer could detect if the same inputs or filters are used in multiple convolutions, and only compute them once. At the moment I’m still unsure whether this is beneficial - perhaps some additional performance could be gained by combining everything into a single, monolithic FFT-convolution operator. But that’s a discussion for another time.</p>
<p>The FFT and IFFT operators were the easiest. scikits.cuda exposes a nice API to perform <strong>batched FFTs</strong>. This allows for GPU-parallelism to be exploited when many FFTs of the same size have to be computed. This is precisely our use case. The API uses the cuFFT implementation internally, which is a part of CUDA.</p>
<p>Interestingly, the authors of the paper I mentioned earlier claim that using cuFFT is not an option because it does not allow this type of parallelism to be exploited, so they made their own CUDA FFT implementation instead. However, I got pretty good results using cuFFT, so I don’t know what led them to make this claim. Perhaps the batched FFT is a recent addition to cuFFT. The same batched approach can be used for the IFFT.</p>
<p>The tough part was performing the actual convolution in the Fourier domain, by computing the complex elementwise products and summing across the input channels. Theano does not have support for complex numbers, so some trickery was required to convert complex arrays into real arrays with an extra trailing dimension of size 2, to contain the real and imaginary parts of the numbers.</p>
<p>I tried a number of different approaches, but what worked best in the end is interpreting the operation as a dot product. A dot product is precisely that: an elementwise product with some broadcasting, followed by summing out a particular dimension. So by reshaping the Fourier-transformed inputs and filters, the multiply-and-sum operation could be translated into a set of dot products. This is great, because GPUs are really good at computing dot products quickly.</p>
<p>It turns out that recent versions of cuBLAS also support <strong>batched dot products</strong>, which offer the same performance advantages as batched FFTs. Since we need to perform a large number of dot products with the same shapes, this was again a perfect match for our use case. The particular function I needed to compute a batched complex-valued dot product is <code class="language-plaintext highlighter-rouge">cublasCgemmBatched</code>. Unfortunately this wasn’t available through scikits.cuda yet, but it wasn’t hard to add the necessary wrappers. I sent a <a href="https://github.com/lebedov/scikits.cuda/pull/52">pull request</a> and it is now included (so make sure to get the latest version of scikits.cuda from git if you want to try this).</p>
<h2 id="proof-of-concept">Proof of concept</h2>
<p>So far I’ve only implemented the <em>valid</em> convolution. Using the implementation in the context of a CNN will also require support for full convolutions - but this is easy to mimic by padding the input with zeros. I have not implemented an optimization that swaps out Theano’s own convolution operator with the FFT-based version, but that is something the Theano team is currently working on.</p>
<p>Preliminary benchmarks show that this implementation is typically faster than cuda-convnet. The table below shows the duration of a single valid convolution computation with the given input and filter shapes, measured on a GeForce GTX 680, averaged across 10 runs, and not taking into account the warmup that the FFT-based implementation requires (the first run will be a bit slower because the FFT plans need to be created).</p>
<p>Following Theano conventions, the input shape is given as <code class="language-plaintext highlighter-rouge">(batch size, number of input channels, width, height)</code> and the filter shape is given as <code class="language-plaintext highlighter-rouge">(number of filters, number of input channels, width, height)</code>. Durations are given for Theano’s own <code class="language-plaintext highlighter-rouge">conv2d</code> implementation, the cuda-convnet wrappers from pylearn2, and the FFT-based implementation. The speedup of the FFT-based implementation over the cuda-convnet wrappers is also given.</p>
<table>
<thead>
<tr>
<th style="text-align: center">input shape</th>
<th style="text-align: center">filter shape</th>
<th style="text-align: center">Theano’s own</th>
<th style="text-align: center">cuda-convnet</th>
<th style="text-align: center">FFT-based</th>
<th style="text-align: center">speedup</th>
</tr>
</thead>
<tbody>
<tr>
<td style="text-align: center">(64, 3, 96, 96)</td>
<td style="text-align: center">(128, 3, 16, 16)</td>
<td style="text-align: center">388.9 ms</td>
<td style="text-align: center">156.9 ms</td>
<td style="text-align: center">117.3 ms</td>
<td style="text-align: center">1.34x</td>
</tr>
<tr>
<td style="text-align: center">(64, 128, 32, 32)</td>
<td style="text-align: center">(64, 128, 8, 8)</td>
<td style="text-align: center">233.9 ms</td>
<td style="text-align: center">87.4 ms</td>
<td style="text-align: center">27.1 ms</td>
<td style="text-align: center">3.23x</td>
</tr>
<tr>
<td style="text-align: center">(128, 32, 54, 54)</td>
<td style="text-align: center">(64, 32, 6, 6)</td>
<td style="text-align: center">457.5 ms</td>
<td style="text-align: center">107.6 ms</td>
<td style="text-align: center">52.2 ms</td>
<td style="text-align: center">2.06x</td>
</tr>
<tr>
<td style="text-align: center">(128, 128, 16, 16)</td>
<td style="text-align: center">(128, 128, 8, 8)</td>
<td style="text-align: center">133.4 ms</td>
<td style="text-align: center">43.5 ms</td>
<td style="text-align: center">18.6 ms</td>
<td style="text-align: center">2.34x</td>
</tr>
<tr>
<td style="text-align: center">(128, 1024, 32, 32)</td>
<td style="text-align: center">(128, 1024, 4, 4)</td>
<td style="text-align: center">6246.2 ms</td>
<td style="text-align: center">1283.5 ms</td>
<td style="text-align: center">357.8 ms</td>
<td style="text-align: center">3.59x</td>
</tr>
</tbody>
</table>
<p>In all cases we get a nice speedup. This approach seems to be the most beneficial when the number of input channels is large - this makes sense, as this is the dimension that is summed over in the batched dot product. But even when this number is small (e.g. 3) it’s still faster.</p>
<h2 id="try-it-out">Try it out</h2>
<p>As mentioned in the introduction, you can grab the code for this at:</p>
<ul>
<li><strong><a href="https://github.com/benanne/theano_fftconv">https://github.com/benanne/theano_fftconv</a></strong></li>
</ul>
<p>All the relevant code is in the file <a href="https://github.com/benanne/theano_fftconv/blob/master/fftconv.py"><code class="language-plaintext highlighter-rouge">fftconv.py</code></a>. The file <a href="https://github.com/benanne/theano_fftconv/blob/master/cufftop.py"><code class="language-plaintext highlighter-rouge">cufftop.py</code></a> was mainly used for experimentation, and contains some alternative implementations of the multiply-and-sum step.</p>
<p>Note that the latest revision of scikits.cuda is required, to ensure that the <code class="language-plaintext highlighter-rouge">cublasCgemmBatched</code> function is available. You’ll also need a working installation of PyCUDA, as this is a dependency of scikits.cuda. And of course, you’ll need Theano and a working CUDA installation.</p>
<p>If you’re patient, you can also wait until the code is available in Theano. Chances are you’ll be able to use it without modifying your existing code, as they are also building an optimization that will replace Theano’s own convolutions with the FFT-based implementation. And if you’re very patient, you can wait until they build the CUDA/C version, which will eliminate the scikits.cuda and PyCUDA dependencies, and hopefully it will be a bit faster as well due to the reduced overhead.</p>
<p>The code to compute the numbers in the table above is in the file <a href="https://github.com/benanne/theano_fftconv/blob/master/speedtest.py"><code class="language-plaintext highlighter-rouge">speedtest.py</code></a>. This script also checks whether the output of all three implementations is the same (up to a given tolerance). More numbers for different input/filter shapes and different GPUs are welcome, so if you run this script on your own machine(s), feel free to send me the results.</p>
<p>Feedback is welcome, and if you’d like to help with integrating this into Theano, <a href="https://groups.google.com/forum/#!topic/theano-users/6xiFFpBBDq0">join the conversation at the theano-users group</a>!</p>