Neural Tangent Kernel

In the study of artificial neural networks (ANNs), the neural tangent kernel (NTK) is a kernel that describes the evolution of deep artificial neural networks during their training by gradient descent. It allows ANNs to be studied using theoretical tools from kernel methods.

In general, a kernel is a positive-semidefinite symmetric function of two inputs which represents some notion of similarity between the two inputs. The NTK is a specific kernel derived from a given neural network; in general, when the neural network parameters change during training, the NTK evolves as well. However, in the limit of large layer width the NTK becomes constant, revealing a duality between training the wide neural network and kernel methods: gradient descent in the infinite-width limit is fully equivalent to kernel gradient descent with the NTK. As a result, using gradient descent to minimize least-squares loss for neural networks yields the same mean estimator as ridgeless kernel regression with the NTK. This duality enables simple closed-form equations describing the training dynamics, generalization, and predictions of wide neural networks.

The NTK was introduced in 2018 by Arthur Jacot, Franck Gabriel and Clément Hongler, who used it to study the convergence and generalization properties of fully connected neural networks. Later works extended the NTK results to other neural network architectures. In fact, the phenomenon behind the NTK is not specific to neural networks and can be observed in generic nonlinear models, usually by a suitable scaling.


Main results (informal)

Let f(x;\theta ) denote the scalar function computed by a given neural network with parameters \theta on input x. Then the neural tangent kernel is defined as

:\Theta (x,x';\theta )=\nabla _{\theta }f(x;\theta )\cdot \nabla _{\theta }f(x';\theta ).

Since it is written as a dot product between mapped inputs (with the gradient of the neural network function serving as the feature map), we are guaranteed that the NTK is symmetric and positive semi-definite. The NTK is thus a valid kernel function.

Consider a fully connected neural network whose parameters are chosen i.i.d. according to any mean-zero distribution. This random initialization of \theta induces a distribution over f(x;\theta ) whose statistics we will analyze, both at initialization and throughout training (gradient descent on a specified dataset). We can visualize this distribution via a neural network ensemble, which is constructed by drawing many times from the initial distribution over f(x;\theta ) and training each draw according to the same training procedure.

The number of neurons in each layer is called the layer's width. Consider taking the width of every hidden layer to infinity and training the neural network with gradient descent (with a suitably small learning rate). In this infinite-width limit, several nice properties emerge:

* At initialization (before training), the neural network ensemble is a zero-mean Gaussian process (GP). This means that the distribution of functions is the maximum-entropy distribution with mean \mathbb{E}_{\theta }[f(x;\theta )]=0 and covariance \mathbb{E}_{\theta }[f(x;\theta )f(x';\theta )]=\Sigma (x,x'), where the GP covariance \Sigma (x,x') can be computed from the network architecture. In other words, the distribution of neural network functions at initialization has no structure other than its first and second moments (mean and covariance). This follows from the central limit theorem.
* The NTK is deterministic. In other words, the NTK is independent of the random parameter initialization.
* The NTK does not change during training.
* Each parameter changes negligibly throughout training. As Lee ''et al.'' note, "although individual parameters move by a vanishingly small amount, they collectively conspire to provide a finite change in the final output of the network, as is necessary for training."
* During training, the neural network is linearized, i.e., its parameter dependence can be captured by its first-order Taylor expansion: f(x;\theta_0 +\Delta \theta )=f(x;\theta_0 )+\Delta \theta \cdot \nabla _{\theta }f(x;\theta_0 ), where \theta_0 are the initial parameters. This follows from the fact that each parameter changes negligibly during training. (The neural network remains nonlinear with respect to the inputs.)
* The training dynamics are equivalent to kernel gradient descent using the NTK as the kernel. If the loss function is mean-squared error, the final distribution over f(x;\theta ) is still a Gaussian process, but with a new mean and covariance. In particular, the mean converges to the same estimator yielded by kernel regression with the NTK as kernel and zero ridge regularization, and the covariance is expressible in terms of the NTK and the initial GP covariance. It can be shown that the ensemble variance vanishes at the training points (in other words, the neural network always interpolates the training data, regardless of initialization).
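
As a concrete illustration of the definition above, the following is a minimal sketch that evaluates the finite-width (empirical) NTK of a small two-layer network on a handful of inputs. The architecture f(x) = a · tanh(Wx) / sqrt(m), the sizes, and the names `f`, `grad_f` and `empirical_ntk` are assumptions chosen for illustration; they are not taken from the article or from any library.

```python
import numpy as np

def f(x, W, a):
    """Hypothetical two-layer network: f(x) = a . tanh(W x) / sqrt(m)."""
    return a @ np.tanh(W @ x) / np.sqrt(a.size)

def grad_f(x, W, a):
    """Gradient of f with respect to all parameters, flattened as [W, a]."""
    h = np.tanh(W @ x)
    grad_W = np.outer(a * (1.0 - h**2), x) / np.sqrt(a.size)
    grad_a = h / np.sqrt(a.size)
    return np.concatenate([grad_W.ravel(), grad_a])

def empirical_ntk(X, W, a):
    """Gram matrix Theta(x_i, x_j) = grad f(x_i) . grad f(x_j)."""
    G = np.stack([grad_f(x, W, a) for x in X])  # one gradient per row
    return G @ G.T

rng = np.random.default_rng(0)
d, m, n = 3, 64, 5                    # input dim, width, number of inputs
W = rng.standard_normal((m, d))
a = rng.standard_normal(m)
X = rng.standard_normal((n, d))

K = empirical_ntk(X, W, a)
print(np.allclose(K, K.T), np.linalg.eigvalsh(K).min() >= -1e-9)
```

Because the Gram matrix is built as G Gᵀ, symmetry and positive semi-definiteness hold by construction, which is the argument made in the text.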


Applications


Ridgeless kernel regression and kernel gradient descent

Kernel methods are machine learning algorithms which use only pairwise relations between input points. Kernel methods do not depend on the concrete values of the inputs; they only depend on the relations between the inputs and other inputs (such as the training set). These pairwise relations are fully captured by the kernel function: a symmetric, positive-semidefinite function of two inputs which represents some notion of similarity between the two inputs. A fully equivalent condition is that there exists some feature map \mathbf{x}\mapsto \psi (\mathbf{x}) such that the kernel function can be written as a dot product of the mapped inputs

:K(\mathbf{x},\mathbf{x}')=\psi (\mathbf{x})\cdot \psi (\mathbf{x}').

The properties of a kernel method depend on the choice of kernel function. (Note that \psi (\mathbf{x}) may have higher dimension than \mathbf{x}.) As a relevant example, consider linear regression. This is the task of estimating \mathbf{w}^{*} given N samples (\mathbf{x}_{i},y_{i}) generated from y^{*}(\mathbf{x})=\mathbf{w}^{*}\cdot \mathbf{x}, where each \mathbf{x}_{i} is drawn according to some input data distribution. In this setup, \mathbf{w}^{*} is the weight vector which defines the true function y^{*}; we wish to use the training samples to develop a model \hat{\mathbf{w}} which approximates \mathbf{w}^{*}. We do this by minimizing the mean-square error between our model and the training samples:

:\hat{\mathbf{w}}=\arg \min _{\mathbf{w}}\sum_{i=1}^{N}\|y^{*}(\mathbf{x}_{i})-\mathbf{w}\cdot \mathbf{x}_{i}\|^{2}

There exists an explicit solution for \hat{\mathbf{w}} which minimizes the squared error: \hat{\mathbf{w}}=(\mathbf{X}\mathbf{X}^{T})^{-1}\mathbf{X}\mathbf{y}, where \mathbf{X} is the matrix whose columns are the training inputs, and \mathbf{y} is the vector of training outputs. Then, the model can make predictions on new inputs: \hat{y}(\mathbf{x})=\hat{\mathbf{w}}\cdot \mathbf{x}.

However, this result can be rewritten as: \hat{y}(\mathbf{x})=(\mathbf{x}^{T}\mathbf{X})(\mathbf{X}^{T}\mathbf{X})^{-1}\mathbf{y}. Note that this dual solution is expressed solely in terms of the inner products between inputs. This motivates extending linear regression to settings in which, instead of directly taking inner products between inputs, we first transform the inputs according to a chosen feature map and then evaluate the inner products between the transformed inputs. As discussed above, this can be captured by a kernel function K(\mathbf{x},\mathbf{x}'), since all kernel functions are inner products of feature-mapped inputs. This yields the ridgeless kernel regression estimator:

:\hat{y}(\mathbf{x})=K(\mathbf{x},\mathbf{X})\;K(\mathbf{X},\mathbf{X})^{-1}\;\mathbf{y}.

If the kernel matrix K(\mathbf{X},\mathbf{X}) is singular, one uses the Moore-Penrose pseudoinverse. The regression equations are called "ridgeless" because they lack a ridge regularization term.

In this view, linear regression is a special case of kernel regression with the identity feature map: \psi (\mathbf{x})=\mathbf{x}. Equivalently, kernel regression is simply linear regression in the feature space (i.e. the range of the feature map defined by the chosen kernel). Note that kernel regression is typically a ''nonlinear'' regression in the input space, which is a major strength of the algorithm.

Just as it is possible to perform linear regression using iterative optimization algorithms such as gradient descent, one can perform kernel regression using kernel gradient descent. This is equivalent to performing gradient descent in the feature space. It is known that if the weight vector is initialized close to zero, least-squares gradient descent converges to the minimum-norm solution, i.e., the final weight vector has the minimum Euclidean norm of all the interpolating solutions. In the same way, kernel gradient descent yields the minimum-norm solution with respect to the RKHS norm. This is an example of the implicit regularization of gradient descent.

The NTK gives a rigorous connection between the inference performed by infinite-width ANNs and that performed by kernel methods: when the loss function is the least-squares loss, the inference performed by an ANN is in expectation equal to ridgeless kernel regression with respect to the NTK. This suggests that the performance of large ANNs in the NTK parametrization can be replicated by kernel methods for suitably chosen kernels.
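
The primal and dual forms of linear regression, and their kernel generalization, can be checked numerically. The sketch below assumes a square, invertible data matrix (as many samples as input dimensions) so that both inverses in the formulas exist; the helper names `kernel_regression`, `linear_kernel` and `rbf_kernel`, and the choice of an RBF kernel as the nonlinear example, are illustrative assumptions rather than anything prescribed by the article.

```python
import numpy as np

def kernel_regression(kernel, X_train, y_train, X_test):
    """Ridgeless kernel regression; rows of X_train / X_test are inputs.
    Assumes an invertible kernel matrix (otherwise use a pseudoinverse)."""
    K_train = kernel(X_train, X_train)
    K_test = kernel(X_test, X_train)
    return K_test @ np.linalg.solve(K_train, y_train)

def linear_kernel(A, B):
    """Kernel of the identity feature map: K(x, x') = x . x'."""
    return A @ B.T

def rbf_kernel(A, B, gamma=0.5):
    """Gaussian (RBF) kernel, used here only as a nonlinear example."""
    sq = np.sum(A**2, axis=1)[:, None] + np.sum(B**2, axis=1)[None, :] - 2.0 * A @ B.T
    return np.exp(-gamma * sq)

rng = np.random.default_rng(0)
d = N = 4
X = rng.standard_normal((d, N))     # columns are training inputs
y = X.T @ rng.standard_normal(d)    # noiseless linear targets
x_new = rng.standard_normal(d)

# Primal solution: w_hat = (X X^T)^{-1} X y, prediction w_hat . x
w_hat = np.linalg.solve(X @ X.T, X @ y)
primal = w_hat @ x_new

# Dual solution: (x^T X)(X^T X)^{-1} y, using only inner products of inputs
dual = (x_new @ X) @ np.linalg.solve(X.T @ X, y)

# Same prediction through the generic kernel estimator with a linear kernel
via_kernel = kernel_regression(linear_kernel, X.T, y, x_new[None, :])[0]

# A nonlinear kernel still interpolates the training targets
rbf_fit = kernel_regression(rbf_kernel, X.T, y, X.T)
```

Under these assumptions `primal`, `dual` and `via_kernel` agree up to rounding, and `rbf_fit` reproduces `y`, which is the interpolation property of the ridgeless estimator.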


Overparametrization, interpolation, and generalization

In overparametrized models, the number of tunable parameters exceeds the number of training samples. In this case, the model is able to memorize (perfectly fit) the training data. Therefore, overparametrized models interpolate the training data, achieving essentially zero training error.

Kernel regression is typically viewed as a non-parametric learning algorithm, since there are no explicit parameters to tune once a kernel function has been chosen. An alternate view is to recall that kernel regression is simply linear regression in feature space, so the "effective" number of parameters is the dimension of the feature space. Therefore, studying kernels with high-dimensional feature maps can provide insights about strongly overparametrized models.

As an example, consider the problem of generalization. According to classical statistics, memorization should cause models to fit noisy signals in the training data, harming their performance on unseen data. To counter this, machine learning algorithms often introduce regularization. Surprisingly, modern neural networks (which tend to be strongly overparametrized) seem to generalize well, even in the absence of explicit regularization. To study the generalization properties of overparametrized neural networks, one can exploit the infinite-width duality with ridgeless kernel regression. Recent works have derived equations describing the expected generalization error of high-dimensional kernel regression; these results immediately explain the generalization of sufficiently wide neural networks trained to convergence on least-squares.


Convergence to a global minimum

For a convex loss functional \mathcal{C} with a global minimum, if the NTK remains positive-definite during training, the loss of the ANN \mathcal{C}\left(f\left(\cdot;\theta \left(t\right)\right)\right) converges to that minimum as t\to \infty. This positive-definiteness property has been shown in a number of cases, yielding the first proofs that large-width ANNs converge to global minima during training.
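
The role of positive-definiteness can be seen in a discrete-time toy version of kernel gradient descent. The sketch below assumes a fixed (constant) kernel matrix on the training points and the squared loss, so that the training outputs follow f ← f − η K (f − y); the matrix `K`, the step size and the function name `kernel_gradient_descent` are illustrative assumptions, not a statement about any particular network.

```python
import numpy as np

def kernel_gradient_descent(K, y, lr, steps):
    """Iterate f <- f - lr * K (f - y) on the training points, starting from f = 0.
    Returns the final outputs and the squared loss after every step."""
    f = np.zeros_like(y)
    losses = []
    for _ in range(steps):
        f = f - lr * K @ (f - y)
        losses.append(0.5 * np.sum((f - y) ** 2))
    return f, np.array(losses)

rng = np.random.default_rng(0)
n = 5
A = rng.standard_normal((n, n))
K = A @ A.T + np.eye(n)                 # positive-definite by construction
y = rng.standard_normal(n)

lr = 1.0 / np.linalg.eigvalsh(K).max()  # step size small enough for stability
f_final, losses = kernel_gradient_descent(K, y, lr, steps=500)
```

With this step size the residual f − y is multiplied at each step by I − ηK, whose eigenvalues lie in [0, 1) when K is positive-definite, so the loss decreases monotonically to the global minimum (zero here). If K had a zero eigenvalue, the component of the residual along that direction would never shrink.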


Extensions and limitations

The NTK can be studied for various ANN architectures, in particular convolutional neural networks (CNNs), recurrent neural networks (RNNs) and transformers. In such settings, the large-width limit corresponds to letting the number of parameters grow, while keeping the number of layers fixed: for CNNs, this involves letting the number of channels grow.

Individual parameters of a wide neural network in the kernel regime change negligibly during training. This implies that infinite-width neural networks in this regime cannot exhibit feature learning, which is widely considered to be an important property of realistic deep neural networks. This is not a generic feature of infinite-width neural networks and is largely due to a specific choice of the scaling by which the width is taken to the infinite limit; indeed several works have found alternate infinite-width scaling limits of neural networks in which there is no duality with kernel regression and feature learning occurs during training. Others introduce a "neural tangent hierarchy" to describe finite-width effects, which may drive feature learning.

Neural Tangents is a free and open-source Python library used for computing and doing inference with the infinite-width NTK and neural network Gaussian process (NNGP) corresponding to various common ANN architectures. In addition, there exists a scikit-learn compatible implementation of the infinite-width NTK for Gaussian processes called scikit-ntk.


Details

When optimizing the parameters \theta\in\mathbb{R}^{P} of an ANN to minimize an empirical loss through gradient descent, the NTK governs the dynamics of the ANN output function f_{\theta} throughout the training.


Case 1: Scalar output

An ANN with scalar output consists of a family of functions f\left(\cdot;\theta\right):\mathbb{R}^{n_{\mathrm{in}}}\to\mathbb{R} parametrized by a vector of parameters \theta\in\mathbb{R}^{P}. The NTK is a kernel \Theta:\mathbb{R}^{n_{\mathrm{in}}}\times\mathbb{R}^{n_{\mathrm{in}}}\to\mathbb{R} defined by

:\Theta\left(x,y;\theta\right)=\sum_{p=1}^{P}\partial_{\theta_{p}}f\left(x;\theta\right)\partial_{\theta_{p}}f\left(y;\theta\right).

In the language of kernel methods, the NTK \Theta is the kernel associated with the feature map \left(x\mapsto\partial_{\theta_{p}}f\left(x;\theta\right)\right)_{p=1,\ldots,P}. To see how this kernel drives the training dynamics of the ANN, consider a dataset \left(x_{i}\right)_{i=1,\ldots,n}\subset\mathbb{R}^{n_{\mathrm{in}}} with scalar labels \left(z_{i}\right)_{i=1,\ldots,n}\subset\mathbb{R} and a loss function c:\mathbb{R}\times\mathbb{R}\to\mathbb{R}. Then the associated empirical loss, defined on functions f:\mathbb{R}^{n_{\mathrm{in}}}\to\mathbb{R}, is given by

:\mathcal{C}\left(f\right)=\sum_{i=1}^{n}c\left(f\left(x_{i}\right),z_{i}\right).

When the ANN f\left(\cdot;\theta\right):\mathbb{R}^{n_{\mathrm{in}}}\to\mathbb{R} is trained to fit the dataset (i.e. minimize \mathcal{C}) via continuous-time gradient descent, the parameters \left(\theta\left(t\right)\right)_{t\geq 0} evolve through the ordinary differential equation:

:\partial_{t}\theta\left(t\right)=-\nabla\mathcal{C}\left(f\left(\cdot;\theta\right)\right).

During training the ANN output function follows an evolution differential equation given in terms of the NTK:

:\partial_{t}f\left(x;\theta\left(t\right)\right)=-\sum_{i=1}^{n}\Theta\left(x,x_{i};\theta\right)\partial_{w}c\left(w,z_{i}\right)\Big|_{w=f\left(x_{i};\theta\left(t\right)\right)}.

This equation shows how the NTK drives the dynamics of f\left(\cdot;\theta\left(t\right)\right) in the space of functions \mathbb{R}^{n_{\mathrm{in}}}\to\mathbb{R} during training.
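
The step from the parameter-space equation to the function-space equation is a direct application of the chain rule. The following derivation is a sketch using only the definitions of this section (the empirical loss \mathcal{C}, the loss function c, and the NTK \Theta); it introduces no new quantities.

```latex
\begin{align}
\partial_{t} f\left(x;\theta(t)\right)
  &= \sum_{p=1}^{P} \partial_{\theta_{p}} f\left(x;\theta\right)\,\partial_{t}\theta_{p}(t) \\
  &= -\sum_{p=1}^{P} \partial_{\theta_{p}} f\left(x;\theta\right)
     \sum_{i=1}^{n} \partial_{w} c\left(w,z_{i}\right)\Big|_{w=f\left(x_{i};\theta\right)}
     \,\partial_{\theta_{p}} f\left(x_{i};\theta\right) \\
  &= -\sum_{i=1}^{n} \Theta\left(x,x_{i};\theta\right)\,
     \partial_{w} c\left(w,z_{i}\right)\Big|_{w=f\left(x_{i};\theta\right)} .
\end{align}
```

The second line substitutes the gradient-flow equation for \partial_{t}\theta_{p}, with the gradient of \mathcal{C} expanded by the chain rule; the third line exchanges the two sums and recognizes the definition of \Theta.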


Case 2: Vector output

An ANN with vector output of size n_{\mathrm{out}} consists of a family of functions f\left(\cdot;\theta\right):\mathbb{R}^{n_{\mathrm{in}}}\to\mathbb{R}^{n_{\mathrm{out}}} parametrized by a vector of parameters \theta\in\mathbb{R}^{P}. In this case, the NTK \Theta:\mathbb{R}^{n_{\mathrm{in}}}\times\mathbb{R}^{n_{\mathrm{in}}}\to\mathcal{M}_{n_{\mathrm{out}}}\left(\mathbb{R}\right) is a matrix-valued kernel, with values in the space of n_{\mathrm{out}}\times n_{\mathrm{out}} matrices, defined by

:\Theta_{k,l}\left(x,y;\theta\right)=\sum_{p=1}^{P}\partial_{\theta_{p}}f_{k}\left(x;\theta\right)\partial_{\theta_{p}}f_{l}\left(y;\theta\right).

Empirical risk minimization proceeds as in the scalar case, with the difference being that the loss function takes vector inputs c:\mathbb{R}^{n_{\mathrm{out}}}\times\mathbb{R}^{n_{\mathrm{out}}}\to\mathbb{R}. The training of f_{\theta\left(t\right)} through continuous-time gradient descent yields the following evolution in function space driven by the NTK:

:\partial_{t}f_{k}\left(x;\theta\left(t\right)\right)=-\sum_{i=1}^{n}\sum_{l=1}^{n_{\mathrm{out}}}\Theta_{k,l}\left(x,x_{i};\theta\right)\partial_{w_{l}}c\left(\left(w_{1},\ldots,w_{n_{\mathrm{out}}}\right),z_{i}\right)\Big|_{w=f\left(x_{i};\theta\left(t\right)\right)}.

This generalizes the equation shown in case 1 for scalar outputs.
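
In matrix form, the vector-output NTK is the product of the Jacobian of the outputs at x with the transposed Jacobian at y. The sketch below approximates the Jacobians by central finite differences so that it works for any model given as a plain function; the helpers `jacobian`, `matrix_ntk` and the toy `linear_model` are hypothetical names introduced only for this example.

```python
import numpy as np

def jacobian(f, theta, x, eps=1e-6):
    """Central-difference Jacobian of f(x, theta) w.r.t. theta, shape (n_out, P)."""
    cols = []
    for p in range(theta.size):
        step = np.zeros_like(theta)
        step[p] = eps
        cols.append((f(x, theta + step) - f(x, theta - step)) / (2.0 * eps))
    return np.stack(cols, axis=1)

def matrix_ntk(f, theta, x, y):
    """Theta_{k,l}(x, y) = sum_p d f_k(x)/d theta_p * d f_l(y)/d theta_p."""
    return jacobian(f, theta, x) @ jacobian(f, theta, y).T

def linear_model(x, theta):
    """Toy model with two outputs: f(x) = W x, where theta is W flattened."""
    W = theta.reshape(2, x.size)
    return W @ x

theta = np.arange(6, dtype=float)
x = np.array([1.0, 2.0, 3.0])
y = np.array([0.5, -1.0, 2.0])
Theta_xy = matrix_ntk(linear_model, theta, x, y)   # 2 x 2 matrix
```

For this linear toy model each output depends on its own row of W only, so the kernel is (x · y) times the 2 × 2 identity matrix; here x · y = 4.5. For a real network one would normally use automatic differentiation instead of finite differences.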


Interpretation

Each data point x_{i} influences the evolution of the output f\left(x;\theta\right) for each input x throughout the training. More concretely, with respect to example i, the NTK value \Theta\left(x,x_{i};\theta\right) determines the influence of the loss gradient \partial_{w}c\left(w,z_{i}\right)\big|_{w=f\left(x_{i};\theta\right)} on the evolution of the ANN output f\left(x;\theta\right) through a gradient descent step. In the scalar case, a step of size \epsilon gives, consistently with the sign of the evolution equation of case 1,

:f\left(x;\theta\left(t+\epsilon\right)\right)-f\left(x;\theta\left(t\right)\right)\approx-\epsilon\sum_{i=1}^{n}\Theta\left(x,x_{i};\theta\left(t\right)\right)\partial_{w}c\left(w,z_{i}\right)\big|_{w=f\left(x_{i};\theta\left(t\right)\right)}.
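
The first-order step formula can be verified on a small example. The sketch below takes one explicit gradient descent step on a tiny two-layer network and compares the actual change of the output at a test point with the NTK prediction, taken with the minus sign of the gradient flow equation. The network f(x) = a · tanh(Wx) / sqrt(m), the squared loss c(w, z) = (w − z)² / 2, and all names and sizes are assumptions made for illustration.

```python
import numpy as np

def f(x, W, a):
    """Hypothetical two-layer network: f(x) = a . tanh(W x) / sqrt(m)."""
    return a @ np.tanh(W @ x) / np.sqrt(a.size)

def grad_f(x, W, a):
    """Gradient of f with respect to all parameters, flattened as [W, a]."""
    h = np.tanh(W @ x)
    grad_W = np.outer(a * (1.0 - h**2), x) / np.sqrt(a.size)
    grad_a = h / np.sqrt(a.size)
    return np.concatenate([grad_W.ravel(), grad_a])

rng = np.random.default_rng(1)
d, m, n, eps = 2, 8, 3, 1e-6
W, a = rng.standard_normal((m, d)), rng.standard_normal(m)
X, z = rng.standard_normal((n, d)), rng.standard_normal(n)
x = rng.standard_normal(d)                      # test input

# Squared loss c(w, z) = (w - z)^2 / 2, so dc/dw = w - z at w = f(x_i)
dc = np.array([f(xi, W, a) - zi for xi, zi in zip(X, z)])
grads = np.stack([grad_f(xi, W, a) for xi in X])
grad_C = dc @ grads                             # gradient of the empirical loss

# One gradient descent step of size eps in parameter space
theta_new = np.concatenate([W.ravel(), a]) - eps * grad_C
W_new, a_new = theta_new[: m * d].reshape(m, d), theta_new[m * d:]

actual = f(x, W_new, a_new) - f(x, W, a)
ntk_row = grads @ grad_f(x, W, a)               # Theta(x, x_i) for each i
predicted = -eps * ntk_row @ dc
```

The two numbers `actual` and `predicted` differ only by terms of second order in the step size, so they agree closely for small `eps`.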


Wide fully-connected ANNs have a deterministic NTK, which remains constant throughout training

Consider an ANN with fully-connected layers \ell=0,\ldots,L of widths n_{0}=n_{\mathrm{in}},n_{1},\ldots,n_{L}=n_{\mathrm{out}}, so that f\left(\cdot;\theta\right)=R_{L-1}\circ\cdots\circ R_{0}, where R_{\ell}=\sigma\circ A_{\ell} is the composition of an affine transformation A_{\ell} with the pointwise application of a nonlinearity \sigma:\mathbb{R}\to\mathbb{R}, where \theta parametrizes the maps A_{0},\ldots,A_{L-1}. The parameters \theta\in\mathbb{R}^{P} are initialized randomly, in an independent, identically distributed way.

As the widths grow, the NTK's scale is affected by the exact parametrization of the A_{\ell}'s and by the parameter initialization. This motivates the so-called NTK parametrization A_{\ell}\left(x\right)=\frac{1}{\sqrt{n_{\ell}}}W^{(\ell)}x+b^{(\ell)}. This parametrization ensures that if the parameters \theta\in\mathbb{R}^{P} are initialized as standard normal variables, the NTK has a finite nontrivial limit. In the large-width limit, the NTK converges to a deterministic (non-random) limit \Theta_{\infty}, which stays constant in time.

The NTK \Theta_{\infty} is explicitly given by \Theta_{\infty}=\Theta^{(L)}, where \Theta^{(L)} is determined by the set of recursive equations:

:\begin{align} \Theta^{(1)}\left(x,y\right) &= \Sigma^{(1)}\left(x,y\right),\\ \Sigma^{(1)}\left(x,y\right) &= \frac{1}{n_{\mathrm{in}}}x^{T}y+1,\\ \Theta^{(\ell+1)}\left(x,y\right) &=\Theta^{(\ell)}\left(x,y\right)\dot{\Sigma}^{(\ell+1)}\left(x,y\right)+\Sigma^{(\ell+1)}\left(x,y\right),\\ \Sigma^{(\ell+1)}\left(x,y\right) &= L_{\Sigma^{(\ell)}}^{\sigma}\left(x,y\right),\\ \dot{\Sigma}^{(\ell+1)}\left(x,y\right) &= L_{\Sigma^{(\ell)}}^{\dot{\sigma}}\left(x,y\right), \end{align}

where L_{K}^{f} denotes the kernel defined in terms of the Gaussian expectation:

:L_{K}^{f}\left(x,y\right)=\mathbb{E}_{\left(X,Y\right)\sim\mathcal{N}\left(0,\begin{pmatrix}K\left(x,x\right)&K\left(x,y\right)\\K\left(y,x\right)&K\left(y,y\right)\end{pmatrix}\right)}\left[f\left(X\right)f\left(Y\right)\right].

In this formula the kernels \Sigma^{(\ell)} are the ANN's so-called activation kernels.
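
For some nonlinearities the Gaussian expectations in the recursion have closed forms, which makes the limiting kernel cheap to evaluate. The sketch below assumes a ReLU nonlinearity, for which the two expectations are given by the standard arc-cosine formulas, and follows the recursion exactly as written in this section. The function names `relu_expectations` and `relu_ntk`, and the `depth` argument (the index of the last kernel in the recursion), are illustrative assumptions.

```python
import numpy as np

def relu_expectations(kxx, kxy, kyy):
    """For (X, Y) ~ N(0, [[kxx, kxy], [kxy, kyy]]) return
    E[relu(X) relu(Y)] and E[relu'(X) relu'(Y)] (arc-cosine closed forms)."""
    norm = np.sqrt(kxx * kyy)
    cos_t = np.clip(kxy / norm, -1.0, 1.0)
    t = np.arccos(cos_t)
    sigma = norm * (np.sin(t) + (np.pi - t) * cos_t) / (2.0 * np.pi)
    sigma_dot = (np.pi - t) / (2.0 * np.pi)
    return sigma, sigma_dot

def relu_ntk(x, y, depth):
    """Theta^(depth)(x, y) from the recursion, with ReLU as the nonlinearity."""
    n_in = x.shape[0]
    # Sigma^(1) on the three pairs needed to continue the recursion
    sxx = x @ x / n_in + 1.0
    sxy = x @ y / n_in + 1.0
    syy = y @ y / n_in + 1.0
    theta = sxy                                   # Theta^(1) = Sigma^(1)
    for _ in range(depth - 1):
        new_sxy, sdot_xy = relu_expectations(sxx, sxy, syy)
        new_sxx, _ = relu_expectations(sxx, sxx, sxx)
        new_syy, _ = relu_expectations(syy, syy, syy)
        theta = theta * sdot_xy + new_sxy         # Theta^(l+1)
        sxx, sxy, syy = new_sxx, new_sxy, new_syy
    return theta

x = np.array([1.0, 0.0, -1.0])
y = np.array([0.5, 2.0, 1.0])
k_xy = relu_ntk(x, y, depth=3)
k_yx = relu_ntk(y, x, depth=3)
```

The kernel depends only on the inputs and the depth, not on any sampled parameters, which reflects the deterministic nature of the limit. For other nonlinearities the expectations generally have to be estimated numerically, for example by quadrature or Monte Carlo sampling.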


Wide fully connected networks are linear in their parameters throughout training

The NTK describes the evolution of neural networks under gradient descent in function space. Dual to this perspective is an understanding of how neural networks evolve in parameter space, since the NTK is defined in terms of the gradient of the ANN's outputs with respect to its parameters. In the infinite-width limit, the connection between these two perspectives becomes especially interesting. The NTK remaining constant throughout training at large widths co-occurs with the ANN being well described throughout training by its first-order Taylor expansion around its parameters at initialization:

:f\left(x;\theta(t)\right) = f\left(x;\theta(0)\right) + \nabla_{\theta}f\left(x;\theta(0)\right) \left(\theta(t) - \theta(0)\right) + \mathcal{O}\left(\min\left(n_1 \dots n_{L-1}\right)^{-\frac{1}{2}}\right) .


See also

* Large width limits of neural networks



External links

* Ananthaswamy, Anil (2021-10-11). "A New Link to an Old Model Could Crack the Mystery of Deep Learning". ''Quanta Magazine''. https://www.quantamagazine.org/a-new-link-to-an-old-model-could-crack-the-mystery-of-deep-learning-20211011/