7.5. Impact of Hyperparameters

Most kernels depend on so-called hyperparameters. For example, the length scale \(l\), which appears in many of the previously stated examples, the noise variance \(\sigma_{\text{noise}}^2\), and the scaling parameter \(\sigma^2\) mentioned in Combination and Modification of Kernels are hyperparameters. In addition to the general choice of the kernel, the exact values of these parameters determine the properties of the underlying Gaussian process. In general, we denote the collection of all hyperparameters of some kernel \(k\) by \(\theta\). Note that \(\theta\) is vector-valued if \(k\) possesses more than one hyperparameter.

For example, the scaled squared exponential kernel with positive hyperparameters \(l\) and \(\sigma^2\) is given by

\[k(x, x^{\prime}) = \sigma^2~\exp \big(-\frac{r^2}{2~l^2} \big)\]

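where \(r = |x - x^{\prime}|\) for \(x, x^{\prime} \in \mathbb{R}^d\). Moreover, noise with variance \(\sigma_{\text{noise}}^2\) is added. Thus, this kernel yields \(\theta = (l, \sigma^2, \sigma_{\text{noise}}^2)\). A larger length scale implies a higher correlation between \(f(x)\) and \(f(x^{\prime})\). The scaling parameter \(\sigma^2\), in contrast, leaves this correlation unchanged: it sets the prior variance of \(f(x)\) and therefore the amplitude of the functions drawn from the process. A larger noise variance \(\sigma_{\text{noise}}^2\) lowers the correlation between the noisy observations at \(x\) and \(x^{\prime}\).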
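
The roles of \(l\) and \(\sigma^2\) can be checked numerically. The following is a minimal sketch in plain NumPy; the function names `scaled_se_kernel` and `correlation` are illustrative and not part of any library. It prints both the covariance and the normalized correlation of \(f(x)\) and \(f(x^{\prime})\) for a few hyperparameter values, so the effect of each parameter can be read off directly.

```python
import numpy as np

def scaled_se_kernel(x, x_prime, length_scale, sigma2):
    """Scaled squared exponential kernel for two points in R^d."""
    diff = np.asarray(x, dtype=float) - np.asarray(x_prime, dtype=float)
    r2 = np.sum(diff ** 2)  # squared Euclidean distance
    return sigma2 * np.exp(-r2 / (2.0 * length_scale ** 2))

def correlation(x, x_prime, length_scale, sigma2):
    """Prior correlation of f(x) and f(x'): covariance divided by both standard deviations."""
    k_xy = scaled_se_kernel(x, x_prime, length_scale, sigma2)
    k_xx = scaled_se_kernel(x, x, length_scale, sigma2)
    k_yy = scaled_se_kernel(x_prime, x_prime, length_scale, sigma2)
    return k_xy / np.sqrt(k_xx * k_yy)

# Two example points at distance r = 1 (values chosen for illustration only)
x, x_prime = [0.0], [1.0]
for length_scale in (0.5, 1.0, 2.0):
    for sigma2 in (1.0, 4.0):
        cov = scaled_se_kernel(x, x_prime, length_scale, sigma2)
        corr = correlation(x, x_prime, length_scale, sigma2)
        print(f"l={length_scale}, sigma2={sigma2}: cov={cov:.4f}, corr={corr:.4f}")
```

The noise variance is deliberately left out here, since it only enters on the diagonal of the covariance matrix of the observations.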

Given a test point \(x^*\), the mean prediction and variance in Gaussian process regression depend on \(K(X, X)\) as well as \(K(x^*, X)\), which in turn depend on \(\theta\). Consequently, the hyperparameters impact the model predictions.
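
To make this dependence explicit, the standard formulas for Gaussian process regression with zero prior mean can be written as follows. The symbols \(y\) (vector of training targets) and \(I\) (identity matrix) are notation assumed here, and the noise term is written separately from \(K(X, X)\):

\[\mathbb{E}[f(x^*)] = K(x^*, X)~\big[K(X, X) + \sigma_{\text{noise}}^2~I\big]^{-1}~y\]

\[\operatorname{Var}[f(x^*)] = k(x^*, x^*) - K(x^*, X)~\big[K(X, X) + \sigma_{\text{noise}}^2~I\big]^{-1}~K(X, x^*)\]

Every kernel evaluation on the right-hand side depends on \(l\) and \(\sigma^2\), and the matrix inverse additionally depends on \(\sigma_{\text{noise}}^2\).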

For the case of the scaled squared exponential kernel with optional noise, the impact is shown in the following visualization which can be used in Google Colab:

# Setup for Google Colab / Jupyter (the "!" and "%" lines are IPython magics)
from IPython.display import clear_output
!pip install ipympl
clear_output()

%matplotlib inline
from ipywidgets import interact, FloatSlider
import numpy as np
import scipy.linalg  # import the submodule explicitly; "import scipy" alone may not load it
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.gaussian_process.kernels import RBF, ConstantKernel, WhiteKernel

sns.set_style('darkgrid')

# Gaussian process posterior
def cond_distr(X1, y1, X2, kernel):
    """
    Calculate the posterior mean and covariance matrix for y2
    based on the corresponding input X2, the observations (y1, X1),
    and the prior kernel function.
    """
    # Kernel of the observations (includes the noise term on the diagonal)
    Σ11 = kernel(X1)
    # Kernel of observations vs to-predict
    Σ12 = kernel(X1, X2)
    # Solve Σ11 @ Z = Σ12 for Z and transpose, i.e. solved = Σ12.T @ inv(Σ11)
    solved = scipy.linalg.solve(Σ11, Σ12, assume_a='pos').T
    # Compute posterior mean
    μ2 = solved @ y1
    # Compute the posterior covariance
    Σ22 = kernel(X2, X2)
    Σ2 = Σ22 - (solved @ Σ12)
    return μ2[:, 0], Σ2  # mean, covariance

def _plot(mean, cov):
    # Uses the global training data X, Y and the evaluation grid t defined below
    # Clip tiny negative variances caused by round-off before taking the root
    std = np.sqrt(np.clip(np.diag(cov), 0, None))
    fig = plt.figure(figsize=(14, 10))
    ax = plt.axes(xlim=(0, 10), ylim=(-2.5, 2.5))
    ax.scatter(X, Y)
    plt.plot(t, np.sin(t), label='true function')
    plt.plot(t, mean, c='purple', label='mean prediction')
    plt.fill_between(t, mean - 1.96 * std, mean + 1.96 * std,
                     color='purple', alpha=.25, label='credible interval')
    plt.legend()
    plt.show()
style = {'description_width': 'initial'}

# sample training data
X = np.array([1.01, 3.51, 4.51, 7.01, 7.91, 9.01]).reshape(-1, 1)
Y = np.sin(X)
nb_steps = 500
delta_t = 10 / nb_steps
t = np.arange(0, 10, delta_t)

@interact(l=FloatSlider(value=2.05, min=1e-5, max=3, step=0.05, 
                        continuous_update=False), 
          sigma2=FloatSlider(value=1., min=1e-5, max=10, step=0.1, 
                             continuous_update=False),
          sigma2_noise=FloatSlider(value=0., min=0., max=1., step=0.01, 
                                   continuous_update=False, style=style))
def anim_hyper(l, sigma2, sigma2_noise):
    kernel = ConstantKernel(constant_value=sigma2) * RBF(length_scale=l) + WhiteKernel(noise_level=sigma2_noise)
    mean, cov = cond_distr(X, Y, t.reshape(-1, 1), kernel) 
    _plot(mean, cov)
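
Outside a notebook, the same effect can be inspected without widgets. The following is a minimal sketch under stated assumptions: it uses a hand-written NumPy version of the scaled squared exponential kernel rather than the scikit-learn kernel objects, and the names `se_kernel` and `gp_posterior` as well as the chosen hyperparameter triples are illustrative. For each setting it prints the largest posterior standard deviation on the grid and the root-mean-square error of the mean prediction against \(\sin\).

```python
import numpy as np

def se_kernel(A, B, length_scale, sigma2):
    """Scaled squared exponential kernel matrix between the rows of A and B."""
    sq_dists = np.sum((A[:, None, :] - B[None, :, :]) ** 2, axis=-1)
    return sigma2 * np.exp(-sq_dists / (2.0 * length_scale ** 2))

def gp_posterior(X_train, y_train, X_test, length_scale, sigma2, sigma2_noise):
    """Posterior mean and standard deviation of f at X_test (zero prior mean)."""
    # Covariance of the noisy observations: kernel matrix plus noise on the diagonal
    K11 = se_kernel(X_train, X_train, length_scale, sigma2)
    K11 = K11 + sigma2_noise * np.eye(len(X_train))
    K12 = se_kernel(X_train, X_test, length_scale, sigma2)
    # solved = K12.T @ inv(K11), computed without forming the inverse
    solved = np.linalg.solve(K11, K12).T
    mean = solved @ y_train
    # Only the diagonal of the posterior covariance is needed; k(x*, x*) = sigma2
    var = sigma2 - np.sum(solved * K12.T, axis=1)
    return mean, np.sqrt(np.clip(var, 0.0, None))

# Same training inputs as in the interactive example
X_train = np.array([1.01, 3.51, 4.51, 7.01, 7.91, 9.01]).reshape(-1, 1)
y_train = np.sin(X_train[:, 0])
X_test = np.linspace(0, 10, 200).reshape(-1, 1)

# Hypothetical settings (l, sigma2, sigma2_noise) chosen for illustration
for l, sigma2, sigma2_noise in [(0.5, 1.0, 0.01), (2.0, 1.0, 0.01), (2.0, 1.0, 0.5)]:
    mean, std = gp_posterior(X_train, y_train, X_test, l, sigma2, sigma2_noise)
    rmse = np.sqrt(np.mean((mean - np.sin(X_test[:, 0])) ** 2))
    print(f"l={l}, sigma2={sigma2}, noise={sigma2_noise}: "
          f"max std={std.max():.3f}, RMSE={rmse:.3f}")
```

A small positive noise variance is used in all settings, which also keeps the linear system well conditioned.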