Neural Tangent Kernel (NTK)


“In short, the NTK captures how the network’s output changes before and after a gradient descent update of the weights”

Let’s start the journey of revealing the black box of neural networks.

Set Up a Neural Network

First of all, we need to define a simple neural network with 2 hidden layers:

$$y(x, w)$$

where $y$ is the neural network with weights $w \in \mathbb{R}^m$, and $\{x, \bar{y}\}^N$ is the dataset, a set of input data and output data with $N$ data points. Since we focus on analyzing the weights $w$, we simplify the notation $y(x, w)$ to $y(w)$.

Suppose we have a regression task on the network $y(w)$. Define

$$L(w) = \frac{1}{N}\,\frac{1}{2}\,\|y(w) - \bar{y}\|_2^2$$

where $L(w)$ is the objective loss we want to minimize. Since the constant factor $\frac{1}{N}$ does not affect the minimizer, we drop it and get a simpler loss function

$$L(w) = \frac{1}{2}\,\|y(w) - \bar{y}\|_2^2$$

To The Limit of Infinite Width

In order to measure how much the weights change while training the neural network, we define a normalized metric as follows

$$\frac{\|w_n - w_0\|_2}{\|w_0\|_2}$$

where $w_n$ and $w_0$ are the weights at the $n$-th training iteration and the initial weights, respectively. $\|w_n - w_0\|_2$ measures how far the parameters $w_n$ have moved from $w_0$, and it is normalized by the 2-norm $\|w_0\|_2$.

[Figure: losses with 3 widths]

[Figure: normalized weight-changes with 3 widths]

As we can see, the change in the weights during training decreases as the width of the network grows. As a result, the trained weights should be very close to the initial weights $w_0$ as the width of the network goes to infinity.
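
To make this concrete, here is a minimal sketch (PyTorch and a toy 1-D regression dataset are assumed; the helper `weight_change` is just an illustrative name) that trains 2-hidden-layer MLPs of several widths and reports the normalized metric $\frac{\|w_n - w_0\|_2}{\|w_0\|_2}$. The printed ratio should shrink as the width grows.

```python
# A minimal sketch (PyTorch assumed): measure ||w_n - w_0||_2 / ||w_0||_2
# after training 2-hidden-layer MLPs of different widths on a toy dataset.
import torch
from torch.nn.utils import parameters_to_vector

def weight_change(width, steps=500, lr=1e-2):
    torch.manual_seed(0)
    x = torch.linspace(-1, 1, 64).unsqueeze(1)        # toy inputs
    y_bar = torch.sin(3 * x)                          # toy targets
    model = torch.nn.Sequential(
        torch.nn.Linear(1, width), torch.nn.ReLU(),
        torch.nn.Linear(width, width), torch.nn.ReLU(),
        torch.nn.Linear(width, 1),
    )
    w0 = parameters_to_vector(model.parameters()).detach().clone()
    opt = torch.optim.SGD(model.parameters(), lr=lr)
    for _ in range(steps):
        opt.zero_grad()
        loss = 0.5 * ((model(x) - y_bar) ** 2).sum()  # L(w) = 1/2 ||y(w) - y_bar||^2
        loss.backward()
        opt.step()
    w_n = parameters_to_vector(model.parameters()).detach()
    return ((w_n - w0).norm() / w0.norm()).item()

for width in (16, 256, 2048):
    print(width, weight_change(width))                # the ratio shrinks as width grows
```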

Apply Taylor Expansion

We know that the Taylor expansion of a function $f$ around a point $a$ is

$$f(x) = \sum_{n=0}^{\infty} \frac{f^{(n)}(a)}{n!}\,(x - a)^n$$

A function $f(w)$ expanded around $w_0$ with a first-order approximation is

$$f(w) \approx f(w_0) + \frac{df(w_0)}{dw}\,(w - w_0)$$

If $w$ is a vector, we simply replace the derivative $\frac{df(w_0)}{dw}$ with the gradient $\nabla_w f(w_0)$

$$f(w) \approx f(w_0) + \nabla_w f(w_0)^\top (w - w_0)$$

Applying this to the network $y(w)$:

$$y(w) \approx y(w_0) + \nabla_w y(w_0)^\top (w - w_0)$$

where $\nabla_w y(w_0)$ and $y(w_0)$ are constants.

Thus, the first-order Taylor expansion of $y(w)$ is just a linear model in $w$. Although the expansion around $w_0$ is not needed for the proof of NTK, it is still a useful tool for analyzing how accurate the linear approximation is for an infinitely wide network.
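
As a quick sanity check, here is a minimal sketch (PyTorch assumed, with a toy one-hidden-layer network whose architecture and step size are purely illustrative) that compares the true network output with the linearized model $y(w_0) + \nabla_w y(w_0)^\top (w - w_0)$ after a small weight perturbation. The wider the network, the closer the two printed values tend to be.

```python
# A minimal sketch (PyTorch assumed): compare y(w) with its first-order
# Taylor expansion around w0 after a small, gradient-descent-like weight step.
import torch

torch.manual_seed(0)
width = 512
net = torch.nn.Sequential(
    torch.nn.Linear(1, width), torch.nn.Tanh(),
    torch.nn.Linear(width, 1),
)
x = torch.tensor([[0.5]])

params = list(net.parameters())
params0 = [p.detach().clone() for p in params]     # w0
y0 = net(x).squeeze()                              # y(w0)
grads = torch.autograd.grad(y0, params)            # grad_w y(w0)

# Perturb the weights a little, as one gradient descent step would.
with torch.no_grad():
    for p, g in zip(params, grads):
        p.add_(-1e-2 * g)

y_true = net(x).squeeze()                          # y(w)
# Linearized prediction: y(w0) + grad_w y(w0) . (w - w0)
y_lin = y0 + sum(
    (g * (p.detach() - p0)).sum()
    for g, p, p0 in zip(grads, params, params0)
)
print(float(y_true), float(y_lin))                 # nearly identical for a wide net
```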

However, the most difficult question is: how can we guarantee that the approximation is accurate enough? The rigorous argument is too complex to include in this article, but I will provide an intuitive explanation of what the NTK means in the following sections. Please keep reading if you are interested.

A Simpler Explanation Without Flow

For simplicity, we only consider a 1-dimensional network $f(x, w)$ with $w, x, \bar{y} \in \mathbb{R}$, for a dataset $x \in X$, $\bar{y} \in \bar{Y}$, which are the input data points and the output data points respectively.

First of all, let’s define the loss function of a neural network

$$L_1(x, w) = \frac{1}{2}\,\|f(x, w) - \bar{y}\|_2^2$$

The gradient descent update is

$$w_{t+1} = w_t - \eta\,\frac{dL_1(x, w_t)}{dw} = w_t - \eta\,(f(x, w_t) - \bar{y})\,\frac{df(x, w_t)}{dw}$$

where η is the learning rate.

The NTK captures how the network output changes before and after one gradient descent update of the weights. Thus, this change can be defined as

$$\lim_{\eta \to 0} \frac{f\!\left(x,\, w - \eta\,\frac{dL_1(x,w)}{dw}\right) - f(x, w)}{\eta}$$

$$= \lim_{\eta \to 0} \frac{f\!\left(x,\, w - \eta\,(f(x,w) - \bar{y})\,\frac{df(x,w)}{dw}\right) - f(x, w)}{\eta}$$

To simplify the notation, let $\Delta w = -\eta\,(f(x,w) - \bar{y})\,\frac{df(x,w)}{dw}$.

We can derive

$$\lim_{\eta \to 0} \frac{f\!\left(x,\, w - \eta\,(f(x,w) - \bar{y})\,\frac{df(x,w)}{dw}\right) - f(x, w)}{\eta}$$

$$= \lim_{\eta \to 0} \frac{f(x, w + \Delta w) - f(x, w)}{\eta}$$

Suppose the learning rate $\eta$ is small enough that $w \approx w + \Delta w$. We can then expand $f(x, w)$ around $w + \Delta w$ with a Taylor expansion

$$f(x, w) \approx f(x, w + \Delta w) + \frac{df(x, w + \Delta w)}{dw}\,\bigl(w - (w + \Delta w)\bigr)$$

$$= f(x, w + \Delta w) - \frac{df(x, w + \Delta w)}{dw}\,\Delta w$$

We can get

$$\lim_{\eta \to 0} \frac{f(x, w + \Delta w) - f(x, w)}{\eta}$$

$$= \lim_{\eta \to 0} \frac{f(x, w + \Delta w) - \left(f(x, w + \Delta w) - \frac{df(x, w + \Delta w)}{dw}\,\Delta w\right)}{\eta}$$

$$= \lim_{\eta \to 0} \frac{1}{\eta}\,\frac{df(x, w + \Delta w)}{dw}\,\Delta w$$

$$= \lim_{\eta \to 0} -\frac{df\!\left(x,\, w - \eta\,(f(x,w) - \bar{y})\,\frac{df(x,w)}{dw}\right)}{dw}\,(f(x,w) - \bar{y})\,\frac{df(x,w)}{dw}$$

$$= -\frac{df(x,w)}{dw}\,(f(x,w) - \bar{y})\,\frac{df(x,w)}{dw}$$

Since the weights barely change, we let $w = w_0$, and the NTK is defined as

$$k_1^{NTK}(x, x) = \frac{df(x,w)}{dw}\,\frac{df(x,w)}{dw} = \frac{df(x,w_0)}{dw}\,\frac{df(x,w_0)}{dw}$$

Since $f(x,w) - \bar{y}$ is very close to 0 once the MSE is close to 0, we can simply drop that factor (together with the sign). This shows that the NTK captures how the output changes before and after a gradient descent step: it measures that change quantitatively, and thus we can approximate the process of gradient descent with a Gaussian process.
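
The derivation above can also be checked numerically. Below is a toy sketch using a hypothetical scalar-parameter model $f(x, w) = \tanh(wx)$ (chosen only for illustration): the finite-difference change of the output after one gradient descent step, divided by $\eta$, should match $-\frac{df}{dw}\,(f - \bar{y})\,\frac{df}{dw}$.

```python
# Toy numeric check of the 1-D derivation, with an illustrative model f(x, w) = tanh(w * x).
import math

def f(x, w):
    return math.tanh(w * x)

def df_dw(x, w):
    return x * (1.0 - math.tanh(w * x) ** 2)       # d/dw tanh(w * x)

x, w, y_bar, eta = 0.7, 0.3, 1.0, 1e-6

# One gradient descent step: w' = w - eta * (f - y_bar) * df/dw
w_new = w - eta * (f(x, w) - y_bar) * df_dw(x, w)

change = (f(x, w_new) - f(x, w)) / eta             # (f(x, w + Δw) - f(x, w)) / eta
predicted = -df_dw(x, w) * (f(x, w) - y_bar) * df_dw(x, w)
print(change, predicted)                           # the two values nearly coincide
```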

Flow And Vector Field

So far, we’ve shown the neural tangent kernel for a 1-dimensional network. To move forward to the infinitely wide network, we need two tools that help us analyze the process of gradient descent in high dimensions. As a result, before diving into the NTK more deeply, we need to understand what Gradient Flow and Vector Field are.

Vector Field

Define a space $\chi \subseteq \mathbb{R}^d$ with $d$ dimensions, a point of the space $x \in \mathbb{R}^d$, and a hypersurface $f(x)$ with $f: \chi \to \mathbb{R}$. We want to find the global minimum point $x^*$

$$x^* = \underset{x \in \chi}{\arg\min}\; f(x)$$

The gradient of the hypersurface, $\nabla_x f: \chi \to \mathbb{R}^d$, gives the gradient at each point of the hypersurface $f$.

Then we define a vector field $F: \chi \to \mathbb{R}^d$ that assigns a velocity vector to each point of the space. Mathematically, the vector field $F$ lives in the same function space as the gradient $\nabla_x f$. As a result, we can also view the negative gradient $-\nabla_x f$ as a vector field, $-\nabla_x f = F(x)$, which assigns a velocity vector $v \in \mathbb{R}^d$ to each point $x \in \chi$.

$$F(x) = -\nabla_x f(x) = v$$

A hypersurface and its gradients can be illustrated as in the following figure. The orange surface represents the hypersurface $f$, and the corresponding gradient $\nabla_x f$ at each point $x \in \chi$ on the hypersurface is drawn as the blue arrows at the bottom. Note that the gradients $\nabla_x f(x)$ here point in the ascent direction, while the gradients used in our optimization problem are the descent directions $-\nabla_x f(x)$; they point in opposite directions. Intuitively, the gradients describe the direction and steepness of the hypersurface at each point, while the vector field gives the velocity vector of each point. Mathematically, the gradients and the vector field live in the same function space, so we equate them by definition rather than for a physical reason.

Then we introduce another variable: time. Let $c(t)$, with $c: \mathbb{R} \to \mathbb{R}^d$, represent the dynamics along time $t$. The function $c(t)$ gives the position in the space $\chi \subseteq \mathbb{R}^d$ at time $t$.

As a result, we know

$$c(t + \delta) = c(t) + \delta\,F(c(t)) = c(t) - \delta\,\nabla_x f(c(t))$$

where $\delta$ is the time step between the two positions. $\delta\,F(c(t)) = -\delta\,\nabla_x f(c(t))$ is the time step multiplied by the velocity vector, which gives the displacement during the time $\delta$.

Gradient Flow

The gradient flow is defined as

$$\dot{c}(t) = \frac{dc(t)}{dt} = F(c(t)) = -\nabla_x f(c(t))$$

The gradient flow describes how the position changes along time by following the negative gradient.
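
In other words, gradient descent with step size $\delta$ is just an Euler discretization of this flow. Here is a minimal sketch (plain Python, with an illustrative quadratic $f(x) = \frac{1}{2}\lambda x^2$ whose flow has the closed form $x(t) = x_0 e^{-\lambda t}$):

```python
# Gradient descent as an Euler discretization of gradient flow.
# For f(x) = 0.5 * lam * x^2, the flow dx/dt = -lam * x solves to x(t) = x0 * exp(-lam * t).
import math

lam, x0, delta, steps = 1.0, 1.0, 1e-3, 5000

x = x0
for _ in range(steps):
    x = x - delta * lam * x          # x <- x - delta * grad f(x)

t = steps * delta
print(x, x0 * math.exp(-lam * t))    # the discrete iterate tracks the continuous flow
```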

Combined With Gradient Flow

We know that the gradient descent update is

$$w_{t+1} = w_t - \eta\,\nabla_w L(w_t)$$

Let the function $w(t) = w_t$ and define the gradient flow over the weights as $\dot{w}(t)$

$$\dot{w}(t) = -\nabla_w L(w(t))$$

Intuitively, the gradient flow $\dot{w}(t) = \frac{dw(t)}{dt}$ describes the direction in which gradient descent changes the weights as time passes.

We expand the gradient of the loss function with the chain rule

$$\dot{w}(t) = -\nabla_w L(w(t))$$

$$= -\nabla_w \frac{1}{2}\,\|y(w(t)) - \bar{y}\|_2^2$$

$$= -\frac{1}{2}\cdot 2\,\nabla_w y(w(t))\,(y(w(t)) - \bar{y})$$

$$= -\nabla_w y(w(t))\,(y(w(t)) - \bar{y})$$

Now we can derive the flow of the network, $\dot{y}(w(t))$

$$\dot{y}(w(t)) = \nabla_w y(w(t))^\top\,\dot{w}(t)$$

$$= -\nabla_w y(w(t))^\top\,\nabla_w y(w(t))\,(y(w(t)) - \bar{y})$$

To simplify the notation, we replace the dynamics $w(t)$ with $w_t$.

$$w(t) = w_t$$

Thus, we get

$$\dot{y}(w_t) = -\nabla_w y(w_t)^\top\,\nabla_w y(w_t)\,(y(w_t) - \bar{y})$$

We now know the mathematical form of the flow $\dot{y}(w_t)$, but what does it mean? Well, we can view the weights $w_t$ updated during gradient descent as a trajectory in a high-dimensional space. Since the learning rate $\eta$ is quite small, the difference in the weights before and after each gradient descent step is very small. As a result, we can treat the discrete progress of gradient descent as a continuous trajectory, as in the following figure. The flow of the network $\dot{y}(w_t)$ is the tangent along this trajectory at $w_t$: it describes the velocity of the point $w_t$ and can predict a close-enough next point $w_{t+1}$.

Since $y(w_t) - \bar{y}$ is also very close to 0 when the MSE is close to 0, we can simply ignore it.

Actually, we are now very close to the neural tangent kernel (NTK). The NTK is a kernel matrix defined as

$$K_{NTK}(x, x') = \nabla_w y(x, w_t)^\top\,\nabla_w y(x', w_t)$$

Since the weights of an infinitely wide network barely change during training, we have

$$y(w_t) \approx y(w_0)$$

We get

$$\nabla_w y(w_t)^\top\,\nabla_w y(w_t) \approx \nabla_w y(w_0)^\top\,\nabla_w y(w_0) = \nabla_w y(x, w_0)^\top\,\nabla_w y(x', w_0)$$

$$= K_{NTK}(x, x')$$

Again, $K_{NTK}(x, x')$ is the Neural Tangent Kernel (NTK).
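
Putting the flow and the kernel together (a small worked step that only combines the definitions above, writing $K_{NTK}$ for the kernel matrix evaluated on the training inputs), the training dynamics of the network outputs can be written in kernel form:

$$\dot{y}(w_t) = -\nabla_w y(w_t)^\top\,\nabla_w y(w_t)\,(y(w_t) - \bar{y}) \approx -K_{NTK}\,(y(w_t) - \bar{y})$$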

The way we measure the distance between two tangents here is cosine similarity via the inner product: the cosine value of two identical vectors is 1, while that of two orthogonal vectors, which are totally different, is 0. With an additional minus sign, we can regard the negative similarity as a kind of distance.
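
Finally, here is a minimal sketch (PyTorch assumed; `empirical_ntk` is just an illustrative helper built on a small one-hidden-layer toy network) that forms the kernel matrix $K_{NTK}(x, x') = \nabla_w y(x, w)^\top \nabla_w y(x', w)$ from per-example gradients and checks that, for a reasonably wide network, it barely changes after a bit of training.

```python
# A minimal sketch (PyTorch assumed) of the empirical NTK matrix
# K_NTK(x, x') = grad_w y(x, w)^T grad_w y(x', w), plus a check that it
# barely moves between w0 and the trained weights when the network is wide.
import torch

def empirical_ntk(model, xs):
    # One flattened gradient vector per input point, stacked into a Jacobian.
    grads = []
    for x in xs:
        model.zero_grad()
        model(x.unsqueeze(0)).squeeze().backward()
        grads.append(torch.cat([p.grad.flatten() for p in model.parameters()]))
    J = torch.stack(grads)           # shape: (num_points, num_params)
    return J @ J.T                   # the kernel matrix

torch.manual_seed(0)
width = 2048
model = torch.nn.Sequential(
    torch.nn.Linear(1, width), torch.nn.Tanh(),
    torch.nn.Linear(width, 1),
)
xs = torch.linspace(-1, 1, 8).unsqueeze(1)
y_bar = torch.sin(3 * xs)

K0 = empirical_ntk(model, xs)        # kernel at initialization w0

# A few small gradient descent steps on the regression loss.
opt = torch.optim.SGD(model.parameters(), lr=1e-3)
for _ in range(200):
    opt.zero_grad()
    (0.5 * (model(xs) - y_bar) ** 2).sum().backward()
    opt.step()

Kt = empirical_ntk(model, xs)        # kernel after training
print(((Kt - K0).norm() / K0.norm()).item())   # small for a wide network
```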

To summarize, the weights of an infinitely wide network almost do not change during training. As a result, the kernel stays almost the same throughout training. We can use the NTK to analyze many properties of neural networks, and neural networks are no longer black boxes.

Papers

NNGP

NTK

Reference

Sincere thanks to the following posts and people.

Gaussian Distribution

NNGP

NTK

Flow

Taylor Expansion