Reference

AutoGrad

AutoGrad.grad (Function)
grad(fun, argnum=1)

Take a function fun(X...)->Y and return another function gfun(X...)->dXi which computes its gradient with respect to positional argument number argnum. The function fun should be scalar-valued. The returned function gfun takes the same arguments as fun, but returns the gradient instead. The gradient has the same type and size as the target argument which can be a Number, Array, Tuple, or Dict.

AutoGrad.gradloss (Function)
gradloss(fun, argnum=1)

Another version of grad where the generated function returns a (gradient,value) pair.

AutoGrad.gradcheck (Function)
gradcheck(f, w, x...; kwargs...)

Numerically check the gradient of f(w,x...;o...) with respect to its first argument w and return a boolean result.

The argument w can be a Number, Array, Tuple or Dict which in turn can contain other Arrays etc. Only the largest 10 entries in each numerical gradient array are checked by default. If the output of f is not a number, gradcheck constructs and checks a scalar function by taking its dot product with a random vector.

Keywords

  • gcheck=10: number of largest entries from each numeric array in gradient dw=(grad(f))(w,x...;o...) compared to their numerical estimates.

  • verbose=false: print detailed messages if true.

  • kwargs=[]: keyword arguments to be passed to f.

  • delta=atol=rtol=cbrt(eps(w)): tolerance parameters. See isapprox for their meaning.
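
The idea behind gradcheck can be sketched in plain Julia without Knet or AutoGrad. This is a minimal sketch under several assumptions. f is scalar-valued and takes a single Float64 array w. The caller supplies the analytic gradient dw. One tolerance stands in for delta, atol and rtol. numgrad and gradcheck_sketch are hypothetical helpers, not AutoGrad functions.

```julia
# Central-difference estimate of the derivative of f with respect to w[i].
function numgrad(f, w::Array{Float64}, i::Int; delta=cbrt(eps(Float64)))
    saved = w[i]
    w[i] = saved + delta; fplus = f(w)
    w[i] = saved - delta; fminus = f(w)
    w[i] = saved                                  # restore the original entry
    return (fplus - fminus) / (2delta)
end

# Compare an analytic gradient dw with numerical estimates on the gcheck
# entries of dw that have the largest magnitude (hypothetical helper).
function gradcheck_sketch(f, dw::Array{Float64}, w::Array{Float64};
                          gcheck=10, tol=cbrt(eps(Float64)))
    idx = partialsortperm(vec(abs.(dw)), 1:min(gcheck, length(dw)); rev=true)
    return all(i -> isapprox(dw[i], numgrad(f, w, i); atol=tol, rtol=tol), idx)
end

cube_loss(w) = sum(w .^ 3)                        # its gradient is 3w.^2
w = [0.5, -1.0, 2.0]
gradcheck_sketch(cube_loss, 3 .* w .^ 2, w)       # true
gradcheck_sketch(cube_loss, 2 .* w, w)            # false: wrong gradient
```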

KnetArray

Knet.KnetArray (Type)
KnetArray{T}(dims)
KnetArray(a::AbstractArray)
Array(k::KnetArray)

Container for GPU arrays that supports most of the AbstractArray interface. The constructor allocates a KnetArray in the currently active device, as specified by gpu(). KnetArrays and Arrays can be converted to each other as shown above, which involves copying to and from the GPU memory. Only Float32/64 KnetArrays are fully supported.

Important differences from the alternative CudaArray are: (1) a custom memory manager that minimizes the number of calls to the slow cudaMalloc by reusing already allocated but garbage collected GPU pointers. (2) a custom getindex that handles ranges such as a[5:10] as views with shared memory instead of copies. (3) custom CUDA kernels that implement elementwise, broadcasting, and reduction operations.

Supported functions:

  • Indexing: getindex, setindex! with the following index types:

    • 1-D: Real, Colon, OrdinalRange, AbstractArray{Real}, AbstractArray{Bool}, CartesianIndex, AbstractArray{CartesianIndex}, EmptyArray, KnetArray{Int32} (low level), KnetArray{0/1} (using float for BitArray) (1-D includes linear indexing of multidimensional arrays)

    • 2-D: (Colon,Union{Real,Colon,OrdinalRange,AbstractVector{Real},AbstractVector{Bool},KnetVector{Int32}}), (Union{Real,AbstractUnitRange,Colon}...) (in any order)

    • N-D: (Real...)

  • Array operations: ==, !=, cat, convert, copy, copy!, deepcopy, display, eachindex, eltype, endof, fill!, first, hcat, isapprox, isempty, length, ndims, ones, pointer, rand!, randn!, reshape, similar, size, stride, strides, summary, vcat, vec, zeros. (cat(i,x,y) supported for i=1,2.)

  • Math operators: (-), abs, abs2, acos, acosh, asin, asinh, atan, atanh, cbrt, ceil, cos, cosh, cospi, erf, erfc, erfcinv, erfcx, erfinv, exp, exp10, exp2, expm1, floor, log, log10, log1p, log2, round, sign, sin, sinh, sinpi, sqrt, tan, tanh, trunc

  • Broadcasting operators: (.*), (.+), (.-), (./), (.<), (.<=), (.!=), (.==), (.>), (.>=), (.^), max, min. (Boolean operators generate outputs with same type as inputs; no support for KnetArray{Bool}.)

  • Reduction operators: countnz, maximum, mean, minimum, prod, sum, sumabs, sumabs2, vecnorm.

  • Linear algebra: (*), axpy!, permutedims (up to 5D), transpose

  • Knet extras: relu, sigm, invx, logp, logsumexp, conv4, pool, deconv4, unpool, mat, update! (Only 4D/5D, Float32/64 KnetArrays support conv4, pool, deconv4, unpool)

Memory management

Knet models do not overwrite arrays that need to be preserved for gradient calculation. This leads to a lot of allocation, and regular GPU memory allocation is prohibitively slow. Fortunately, most models use identically sized arrays over and over again, so we can minimize the number of actual allocations by reusing preallocated but garbage-collected pointers.

When Julia gc reclaims a KnetArray, a special finalizer keeps its pointer in a table instead of releasing the memory. If an array with the same size in bytes is later requested, the same pointer is reused. The exact algorithm for allocation is:

  1. Try to find a previously allocated and garbage collected pointer in the current device. (0.5 μs)

  2. If not available, try to allocate a new array using cudaMalloc. (10 μs)

  3. If not successful, try running gc() and see if we get a pointer of the right size. (75 ms, but this should be amortized over all reusable pointers that become available due to the gc)

  4. Finally if all else fails, clean up all saved pointers in the current device using cudaFree and try allocation one last time. (25-70 ms, however this causes the elimination of all reusable pointers)
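
The reuse idea behind steps 1, 2 and 4 can be modelled on the CPU. This toy is an illustration under explicit assumptions, not Knet's allocator. Vector{UInt8} buffers stand in for GPU pointers, and a counter stands in for cudaMalloc calls. The gc() step is omitted. ToyPool, toyalloc, toyfree! and toyreset! are hypothetical names.

```julia
# Toy model of pointer reuse: released buffers are kept in a table keyed by
# size in bytes and handed out again for same-sized requests.
mutable struct ToyPool
    free::Dict{Int,Vector{Vector{UInt8}}}   # nbytes => saved buffers
    nmalloc::Int                             # stands in for cudaMalloc calls
end
ToyPool() = ToyPool(Dict{Int,Vector{Vector{UInt8}}}(), 0)

function toyalloc(pool::ToyPool, nbytes::Int)
    saved = get(pool.free, nbytes, nothing)
    if saved !== nothing && !isempty(saved)
        return pop!(saved)                   # step 1: reuse a saved buffer
    end
    pool.nmalloc += 1                        # step 2: a fresh allocation
    return Vector{UInt8}(undef, nbytes)
end

# What the finalizer does: save the buffer instead of releasing it.
function toyfree!(pool::ToyPool, buf::Vector{UInt8})
    push!(get!(() -> Vector{UInt8}[], pool.free, length(buf)), buf)
    return nothing
end

# Step 4 (and knetgc): forget every saved buffer.
toyreset!(pool::ToyPool) = empty!(pool.free)

pool = ToyPool()
a = toyalloc(pool, 4096)
toyfree!(pool, a)
b = toyalloc(pool, 4096)   # the same buffer as a; pool.nmalloc is still 1
```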

Utilities

Knet.accuracy (Function)
accuracy(scores, answers, d=1; average=true)

Given an unnormalized scores matrix and an Integer array of correct answers, return the ratio of instances where the correct answer has the maximum score. d=1 means instances are in columns, d=2 means instances are in rows. Use average=false to return the number of correct answers instead of the ratio.
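
A minimal CPU sketch of this computation follows. It assumes a plain Julia matrix of scores and a vector of integer answers. accuracy_sketch is a hypothetical name, not Knet's implementation.

```julia
# Ratio (or count) of instances whose highest score is the correct answer.
# d=1: instances are the columns of scores; d=2: instances are the rows.
function accuracy_sketch(scores::AbstractMatrix, answers::AbstractVector{<:Integer},
                         d::Int=1; average::Bool=true)
    instances = d == 1 ? eachcol(scores) : eachrow(scores)
    predicted = [argmax(s) for s in instances]
    ncorrect = count(predicted .== answers)
    return average ? ncorrect / length(answers) : ncorrect
end

scores = [0.1 2.0 0.3; 1.5 0.2 0.9]                 # 2 classes x 3 instances
accuracy_sketch(scores, [2, 1, 1])                  # 2/3: the third instance is wrong
accuracy_sketch(scores, [2, 1, 1]; average=false)   # 2
```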

accuracy(model, data, predict; average=true)

Compute accuracy(predict(model,x), y) for (x,y) in data and return the ratio (if average=true) or the count (if average=false) of correct answers.

Knet.dir (Function)
Knet.dir(path...)

Construct a path relative to Knet root.

Example

julia> Knet.dir("examples","mnist.jl")
"/home/dyuret/.julia/v0.5/Knet/examples/mnist.jl"

Knet.dropout (Function)
dropout(x, p)

Given an array x and probability 0<=p<=1, just return x if testing, return an array y in which each element is 0 with probability p or x[i]/(1-p) with probability 1-p if training. Training mode is detected automatically based on the type of x, which is AutoGrad.Rec during gradient calculation. Use the keyword argument training::Bool to change the default mode and seed::Number to set the random number seed for reproducible results. See (Srivastava et al. 2014) for a reference.
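
The behavior described above is the usual "inverted" dropout. This minimal sketch assumes plain Julia arrays and an explicit training flag, where Knet instead detects training mode automatically through AutoGrad.Rec. dropout_sketch and its rng keyword are hypothetical.

```julia
using Random

# Inverted dropout: each element becomes 0 with probability p, otherwise
# x[i]/(1-p), so the expected value of every element is unchanged.
function dropout_sketch(x::AbstractArray{T}, p::Real; training::Bool=true,
                        rng::AbstractRNG=Random.default_rng()) where {T<:AbstractFloat}
    0 <= p <= 1 || throw(ArgumentError("p must be in [0, 1]"))
    (!training || p == 0) && return x      # testing: return x unchanged
    p == 1 && return zero(x)               # every element is dropped
    keep = rand(rng, T, size(x)) .>= p     # true with probability 1-p
    return x .* keep ./ T(1 - p)
end

y = dropout_sketch(ones(8), 0.5; rng=MersenneTwister(1))   # entries are 0.0 or 2.0
```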

Knet.gpu (Function)

gpu() returns the id of the active GPU device or -1 if none are active.

gpu(true) resets all GPU devices and activates the one with the most available memory.

gpu(false) resets and deactivates all GPU devices.

gpu(d::Int) activates the GPU device d if 0 <= d < gpuCount(), otherwise deactivates devices.

gpu(true/false) resets all devices. If there are any allocated KnetArrays their pointers will be left dangling. Thus gpu(true/false) should only be used during startup. If you want to suspend GPU use temporarily, use gpu(-1).

gpu(d::Int) does not reset the devices. You can select a previous device and find allocated memory preserved. However trying to operate on arrays of an inactive device will result in error.

Knet.invx (Function)

invx(x) = 1 ./ x

Knet.knetgc (Function)
knetgc(dev=gpu())

cudaFree all pointers allocated on device dev that were previously allocated and garbage collected. Normally Knet holds on to all garbage collected pointers for reuse. Try this if you run out of GPU memory.

Knet.logp (Function)
logp(x,[dims])

Treat entries in x as unnormalized log probabilities and return normalized log probabilities.

dims is an optional argument, if not specified the normalization is over the whole x, otherwise the normalization is performed over the given dimensions. In particular, if x is a matrix, dims=1 normalizes columns of x and dims=2 normalizes rows of x.

Knet.logsumexp (Function)
logsumexp(x,[dims])

Compute log(sum(exp(x),dims)) in a numerically stable manner.

dims is an optional argument, if not specified the summation is over the whole x, otherwise the summation is performed over the given dimensions. In particular if x is a matrix, dims=1 sums columns of x and dims=2 sums rows of x.
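
Both logsumexp and logp can be sketched with the standard max-subtraction trick. This sketch assumes plain Julia arrays. It takes dims as a keyword (with : meaning all of x) rather than as the optional positional argument shown above. logsumexp_sketch and logp_sketch are hypothetical names.

```julia
# Stable log(sum(exp.(x); dims)): subtract the maximum so exp cannot overflow,
# then add it back. dims=: reduces over all of x.
function logsumexp_sketch(x::AbstractArray; dims=:)
    m = maximum(x; dims=dims)
    return m .+ log.(sum(exp.(x .- m); dims=dims))
end

# logp: normalized log probabilities, i.e. x minus its logsumexp over dims.
logp_sketch(x::AbstractArray; dims=:) = x .- logsumexp_sketch(x; dims=dims)

x = [1000.0 1.0; 1000.0 2.0]
logsumexp_sketch(x; dims=1)                 # 1x2 row: 1000 + log(2), log(e + e^2); no overflow
sum(exp.(logp_sketch(x; dims=1)); dims=1)   # each column sums to 1 up to rounding
```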

Knet.minibatch (Function)
minibatch(x, y, batchsize; shuffle, partial, xtype, ytype)

Return an iterable of minibatches [(xi,yi)...] given data tensors x, y and batchsize. The last dimension of x and y should match and give the number of instances. Keyword arguments:

  • shuffle=false: Shuffle the instances before minibatching.

  • partial=false: If true include the last partial minibatch < batchsize.

  • xtype=typeof(x): Convert xi in minibatches to this type.

  • ytype=typeof(y): Convert yi in minibatches to this type.
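
The (x, y) form can be sketched as follows. The sketch assumes plain Julia arrays, returns copied batches in a Vector rather than a lazy iterator, and omits the xtype and ytype conversions. minibatch_sketch is a hypothetical name.

```julia
using Random

# Split x and y along their last dimension into (xi, yi) batches of
# batchsize instances. shuffle permutes the instances first; partial keeps
# a final batch with fewer than batchsize instances.
function minibatch_sketch(x::AbstractArray, y::AbstractArray, batchsize::Int;
                          shuffle::Bool=false, partial::Bool=false)
    n = size(x, ndims(x))
    n == size(y, ndims(y)) ||
        throw(DimensionMismatch("x and y must have the same number of instances"))
    order = shuffle ? randperm(n) : collect(1:n)
    stop = partial ? n : n - n % batchsize         # drop the remainder unless partial
    pick(a, j) = copy(selectdim(a, ndims(a), j))   # instances j of a, as a new array
    return [(pick(x, order[i:min(i + batchsize - 1, stop)]),
             pick(y, order[i:min(i + batchsize - 1, stop)])) for i in 1:batchsize:stop]
end

x = reshape(collect(1.0:20.0), 2, 10)             # 10 instances with 2 features each
y = collect(1:10)
length(minibatch_sketch(x, y, 3))                 # 3 (the 10th instance is dropped)
length(minibatch_sketch(x, y, 3; partial=true))   # 4
```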

minibatch(x, batchsize; shuffle, partial, xtype, ytype)

Return an iterable of minibatches [x1,x2,...] given data tensor x and batchsize. The last dimension of x gives the number of instances. Keyword arguments:

  • shuffle=false: Shuffle the instances before minibatching.

  • partial=false: If true include the last partial minibatch < batchsize.

  • xtype=typeof(x): Convert xi in minibatches to this type.

Knet.nll (Function)
nll(scores, answers, d=1; average=true)

Given an unnormalized scores matrix and an Integer array of correct answers, return the per-instance negative log likelihood. d=1 means instances are in columns, d=2 means instances are in rows. Use average=false to return the sum instead of per-instance average.
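
A minimal CPU sketch of the per-instance negative log likelihood follows. It treats each column of scores as unnormalized log probabilities, as in logp. nll_sketch is a hypothetical name, not Knet's implementation.

```julia
# Average (or total) negative log likelihood of the correct answers.
# scores are unnormalized log probabilities; d=1: instances are columns.
function nll_sketch(scores::AbstractMatrix, answers::AbstractVector{<:Integer},
                    d::Int=1; average::Bool=true)
    s = d == 1 ? scores : permutedims(scores)     # make instances columns
    m = maximum(s; dims=1)
    logz = m .+ log.(sum(exp.(s .- m); dims=1))   # stable log normalizer per column
    total = -sum(s[answers[j], j] - logz[j] for j in 1:length(answers))
    return average ? total / length(answers) : total
end

scores = [0.0 0.0; 0.0 log(3.0)]   # column 2 gives class 2 probability 3/4
nll_sketch(scores, [1, 2])          # (log(2) + log(4/3)) / 2
```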

nll(model, data, predict; average=true)

Compute nll(predict(model,x), y) for (x,y) in data and return the per-instance average (if average=true) or total (if average=false) negative log likelihood.

Knet.relu (Function)

relu(x) = max(0,x)

Knet.setseed (Function)
setseed(n::Integer)

Run srand(n) on both the CPU and the GPU.

Knet.sigm (Function)

sigm(x) = 1 ./ (1 .+ exp.(-x))

Convolution and Pooling

Knet.conv4 (Function)
conv4(w, x; kwargs...)

Execute convolutions or cross-correlations using filters specified with w over tensor x.

Currently KnetArray{Float32/64,4/5} and Array{Float32/64,4} are supported as w and x. If w has dimensions (W1,W2,...,I,O) and x has dimensions (X1,X2,...,I,N), the result y will have dimensions (Y1,Y2,...,O,N) where

Yi=1+floor((Xi+2*padding[i]-Wi)/stride[i])

Here I is the number of input channels, O is the number of output channels, N is the number of instances, and Wi,Xi,Yi are spatial dimensions. padding and stride are keyword arguments that can be specified as a single number (in which case they apply to all dimensions), or an array/tuple with entries for each spatial dimension.
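
The output-size formula can be checked with a small helper. This sketch assumes that padding and stride are integers, either one for all spatial dimensions or one per dimension. conv_outsize is a hypothetical name. The channel dimensions (I becomes O) and the instance count N are not affected by the formula.

```julia
# Spatial output size of conv4, following Yi = 1 + floor((Xi + 2*padding[i] - Wi) / stride[i]).
# padding and stride may be one integer for all dimensions or one entry per dimension.
function conv_outsize(xdims::NTuple{N,Int}, wdims::NTuple{N,Int};
                      padding=0, stride=1) where {N}
    pad = padding isa Integer ? ntuple(_ -> padding, N) : Tuple(padding)
    str = stride isa Integer ? ntuple(_ -> stride, N) : Tuple(stride)
    return ntuple(i -> 1 + fld(xdims[i] + 2 * pad[i] - wdims[i], str[i]), N)
end

conv_outsize((28, 28), (5, 5))                        # (24, 24)
conv_outsize((28, 28), (5, 5); padding=2)             # (28, 28)
conv_outsize((32, 32), (3, 3); padding=1, stride=2)   # (16, 16)
```

For example, a 28x28x1x100 input with 5x5x1x20 filters and default keywords would give a 24x24x20x100 result.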

Keywords

  • padding=0: the number of extra zeros implicitly concatenated at the start and at the end of each dimension.

  • stride=1: the number of elements to slide to reach the next filtering window.

  • upscale=1: upscale factor for each dimension.

  • mode=0: 0 for convolution and 1 for cross-correlation.

  • alpha=1: can be used to scale the result.

  • handle: handle to a previously created cuDNN context. Defaults to a Knet allocated handle.

Knet.deconv4 (Function)

Deconvolution; reverse of convolution.

Knet.mat (Function)
mat(x)

Reshape x into a two-dimensional matrix.

This is typically used when turning the output of a 4-D convolution result into a 2-D input for a fully connected layer. For 1-D inputs returns reshape(x, (length(x),1)). For inputs with more than two dimensions of size (X1,X2,...,XD), returns

reshape(x, (X1*X2*...*X[D-1],XD))

Knet.pool (Function)
pool(x; kwargs...)

Compute pooling of input values (i.e., the maximum or average of several adjacent values) to produce an output with smaller height and/or width.

Currently 4 or 5 dimensional KnetArrays with Float32 or Float64 entries are supported. If x has dimensions (X1,X2,...,I,N), the result y will have dimensions (Y1,Y2,...,I,N) where

Yi=1+floor((Xi+2*padding[i]-window[i])/stride[i])

Here I is the number of input channels, N is the number of instances, and Xi,Yi are spatial dimensions. window, padding and stride are keyword arguments that can be specified as a single number (in which case they apply to all dimensions), or an array/tuple with entries for each spatial dimension.

Keywords:

  • window=2: the pooling window size for each dimension.

  • padding=0: the number of extra zeros implicitly concatenated at the start and at the end of each dimension.

  • stride=window: the number of elements to slide to reach the next pooling window.

  • mode=0: 0 for max, 1 for average including padded values, 2 for average excluding padded values.

  • maxpoolingNanOpt=0: NaN values are not propagated if 0; they are propagated if 1.

  • alpha=1: can be used to scale the result.

  • handle: Handle to a previously created cuDNN context. Defaults to a Knet allocated handle.
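
A naive CPU sketch of max and average pooling over a single 2-D channel follows. It assumes no padding, in which case modes 1 and 2 coincide, and an integer window and stride shared by both dimensions. pool2d_sketch is hypothetical and much simpler than the cuDNN-backed pool.

```julia
# Naive pooling over one 2-D channel without padding.
# mode=0: max; any other mode: average. stride defaults to window.
function pool2d_sketch(x::AbstractMatrix; window::Int=2, stride::Int=window, mode::Int=0)
    Y1 = 1 + fld(size(x, 1) - window, stride)   # output size, as in the formula above
    Y2 = 1 + fld(size(x, 2) - window, stride)
    y = similar(x, float(eltype(x)), Y1, Y2)
    for j in 1:Y2, i in 1:Y1
        rows = (i - 1) * stride .+ (1:window)
        cols = (j - 1) * stride .+ (1:window)
        block = view(x, rows, cols)
        y[i, j] = mode == 0 ? maximum(block) : sum(block) / length(block)
    end
    return y
end

x = [1.0 2 3 4; 5 6 7 8; 9 10 11 12; 13 14 15 16]
pool2d_sketch(x)            # [6.0 8.0; 14.0 16.0]
pool2d_sketch(x; mode=1)    # [3.5 5.5; 11.5 13.5]
```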

Knet.unpool (Function)

Unpooling; reverse of pooling.

x == pool(unpool(x;o...); o...)

Recurrent neural networks

Knet.rnninit (Function)
rnninit(inputSize, hiddenSize; opts...)

Return an (r,w) pair where r is an RNN struct and w is a single weight array that includes all matrices and biases for the RNN. Keyword arguments:

  • rnnType=:lstm: Type of RNN, one of :relu, :tanh, :lstm, :gru.

  • numLayers=1: Number of RNN layers.

  • bidirectional=false: Create a bidirectional RNN if true.

  • dropout=0.0: Dropout probability. Ignored if numLayers==1.

  • skipInput=false: Do not multiply the input with a matrix if true.

  • dataType=Float32: Data type to use for weights.

  • algo=0: Algorithm to use, see CUDNN docs for details.

  • seed=0: Random number seed. Uses time() if 0.

  • winit=xavier: Weight initialization method for matrices.

  • binit=zeros: Weight initialization method for bias vectors.

  • usegpu=(gpu()>=0): GPU used by default if one exists.

RNNs compute the output h[t] for a given iteration from the recurrent input h[t-1] and the previous layer input x[t] given matrices W, R and biases bW, bR from the following equations:

:relu and :tanh: Single gate RNN with activation function f:

h[t] = f(W * x[t] .+ R * h[t-1] .+ bW .+ bR)

:gru: Gated recurrent unit:

i[t] = sigm(Wi * x[t] .+ Ri * h[t-1] .+ bWi .+ bRi) # input gate
r[t] = sigm(Wr * x[t] .+ Rr * h[t-1] .+ bWr .+ bRr) # reset gate
n[t] = tanh(Wn * x[t] .+ r[t] .* (Rn * h[t-1] .+ bRn) .+ bWn) # new gate
h[t] = (1 - i[t]) .* n[t] .+ i[t] .* h[t-1]
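
One GRU step can be written directly from these equations. This sketch assumes column-vector inputs and weights stored in NamedTuples with fields i, r and n. That layout is hypothetical and chosen for readability; it is not the single packed weight array w that rnninit returns. sigmoid stands in for sigm.

```julia
sigmoid(x) = 1 / (1 + exp(-x))

# One GRU step from the equations above. W, R, bW, bR are NamedTuples with
# fields i, r, n (an assumed layout, not rnninit's packed w).
function gru_step(W, R, bW, bR, x, h)
    i = sigmoid.(W.i * x .+ R.i * h .+ bW.i .+ bR.i)          # input gate
    r = sigmoid.(W.r * x .+ R.r * h .+ bW.r .+ bR.r)          # reset gate
    n = tanh.(W.n * x .+ r .* (R.n * h .+ bR.n) .+ bW.n)      # new gate
    return (1 .- i) .* n .+ i .* h                            # h[t]
end

H, X = 3, 2
W = (i=randn(H, X), r=randn(H, X), n=randn(H, X))
R = (i=randn(H, H), r=randn(H, H), n=randn(H, H))
b = (i=zeros(H), r=zeros(H), n=zeros(H))
h1 = gru_step(W, R, b, b, randn(X), zeros(H))   # next hidden state, length H
```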

:lstm: Long short term memory unit with no peephole connections:

i[t] = sigm(Wi * x[t] .+ Ri * h[t-1] .+ bWi .+ bRi) # input gate
f[t] = sigm(Wf * x[t] .+ Rf * h[t-1] .+ bWf .+ bRf) # forget gate
o[t] = sigm(Wo * x[t] .+ Ro * h[t-1] .+ bWo .+ bRo) # output gate
n[t] = tanh(Wn * x[t] .+ Rn * h[t-1] .+ bWn .+ bRn) # new gate
c[t] = f[t] .* c[t-1] .+ i[t] .* n[t]               # cell output
h[t] = o[t] .* tanh(c[t])
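
Similarly, one LSTM step can be sketched from these equations. This assumes weights in NamedTuples with fields i, f, o and n. That layout is hypothetical and is not rnninit's packed w. sigmoid stands in for sigm.

```julia
sigmoid(x) = 1 / (1 + exp(-x))

# One LSTM step (no peepholes) from the equations above. W, R, bW, bR are
# NamedTuples with fields i, f, o, n (an assumed layout, not rnninit's packed w).
function lstm_step(W, R, bW, bR, x, h, c)
    gate(k, act) = act.(W[k] * x .+ R[k] * h .+ bW[k] .+ bR[k])
    i = gate(:i, sigmoid)          # input gate
    f = gate(:f, sigmoid)          # forget gate
    o = gate(:o, sigmoid)          # output gate
    n = gate(:n, tanh)             # new gate
    cnew = f .* c .+ i .* n        # cell state c[t]
    return o .* tanh.(cnew), cnew  # (h[t], c[t])
end

H, X = 2, 3
keys4 = (:i, :f, :o, :n)
W = NamedTuple{keys4}(ntuple(_ -> randn(H, X), 4))
R = NamedTuple{keys4}(ntuple(_ -> randn(H, H), 4))
b = NamedTuple{keys4}(ntuple(_ -> zeros(H), 4))
h1, c1 = lstm_step(W, R, b, b, randn(X), zeros(H), zeros(H))
```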

Knet.rnnforw (Function)
rnnforw(r, w, x[, hx, cx]; batchSizes, hy, cy)

Returns a tuple (y,hyout,cyout,rs) given rnn r, weights w, input x and optionally the initial hidden and cell states hx and cx (cx is only used in LSTMs). r and w should come from a previous call to rnninit. Both hx and cx are optional, they are treated as zero arrays if not provided. The output y contains the hidden states of the final layer for each time step, hyout and cyout give the final hidden and cell states for all layers, rs is a buffer the RNN needs for its gradient calculation.

The boolean keyword arguments hy and cy control whether hyout and cyout will be output. By default hy = (hx!=nothing) and cy = (cx!=nothing && r.mode==2), i.e. a hidden state will be output if one is provided as input and for cell state we also require an LSTM. If hy/cy is false, hyout/cyout will be nothing. batchSizes can be an integer array that specifies non-uniform batch sizes as explained below. By default batchSizes=nothing and the same batch size, size(x,2), is used for all time steps.

The input and output dimensions are:

  • x: (X,[B,T])

  • y: (H/2H,[B,T])

  • hx,cx,hyout,cyout: (H,B,L/2L)

  • batchSizes: nothing or Vector{Int}(T)

where X is inputSize, H is hiddenSize, B is batchSize, T is seqLength, L is numLayers. x can be 1, 2, or 3 dimensional. If batchSizes==nothing, a 1-D x represents a single instance, a 2-D x represents a single minibatch, and a 3-D x represents a sequence of identically sized minibatches. If batchSizes is an array of (non-increasing) integers, it gives us the batch size for each time step in the sequence, in which case sum(batchSizes) should equal div(length(x),size(x,1)). y has the same dimensionality as x, differing only in its first dimension, which is H if the RNN is unidirectional, 2H if bidirectional. Hidden vectors hx, cx, hyout, cyout all have size (H,B1,L) for unidirectional RNNs, and (H,B1,2L) for bidirectional RNNs where B1 is the size of the first minibatch.
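
The dimension rules above can be summarized in a helper that predicts the sizes of y and of the hidden-state arrays. This sketch assumes that the size of x is given as a tuple. rnn_shapes is hypothetical and does not call Knet.

```julia
# Expected array sizes for rnnforw under the dimension rules above.
# xsize is size(x); H = hiddenSize, L = numLayers.
function rnn_shapes(xsize::Tuple, H::Int, L::Int; bidirectional::Bool=false,
                    batchSizes=nothing)
    X = xsize[1]
    ncols = div(prod(xsize), X)                   # instances summed over time steps
    if batchSizes !== nothing
        sum(batchSizes) == ncols ||
            throw(ArgumentError("sum(batchSizes) must equal div(length(x), size(x,1))"))
        issorted(batchSizes; rev=true) ||
            throw(ArgumentError("batchSizes must be non-increasing"))
        B1 = first(batchSizes)
    else
        B1 = length(xsize) >= 2 ? xsize[2] : 1
    end
    D = bidirectional ? 2 : 1
    ysize = (D * H, Base.tail(xsize)...)          # like x except the first dimension
    hsize = (H, B1, D * L)                        # hx, cx, hyout, cyout
    return ysize, hsize
end

rnn_shapes((10, 32, 5), 64, 2)                         # ((64, 32, 5), (64, 32, 2))
rnn_shapes((10, 9), 8, 1; batchSizes=[4, 3, 2])        # ((8, 9), (8, 4, 1))
```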

Knet.rnnparam (Function)
rnnparam{T}(r::RNN, w::KnetArray{T}, layer, id, param)

Return a single weight matrix or bias vector as a slice of w.

Valid layer values:

  • For unidirectional RNNs 1:numLayers

  • For bidirectional RNNs 1:2*numLayers, forw and back layers alternate.

Valid id values:

  • For RELU and TANH RNNs, input = 1, hidden = 2.

  • For GRU reset = 1,4; update = 2,5; newmem = 3,6; 1:3 for input, 4:6 for hidden

  • For LSTM inputgate = 1,5; forget = 2,6; newmem = 3,7; output = 4,8; 1:4 for input, 5:8 for hidden

Valid param values:

  • Return the weight matrix (transposed!) if param==1.

  • Return the bias vector if param==2.

The effect of skipInput: Let I=1 for RELU/TANH, 1:3 for GRU, 1:4 for LSTM

  • For skipInput=false (default), rnnparam(r,w,1,I,1) is a (inputSize,hiddenSize) matrix.

  • For skipInput=true, rnnparam(r,w,1,I,1) is nothing.

  • For bidirectional, the same applies to rnnparam(r,w,2,I,1): the first back layer.
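
You can estimate how many entries w holds by counting the matrices and biases in the RNN equations. Each gate contributes one input matrix, one recurrent matrix and two bias vectors, per direction and per layer. This sketch assumes that w contains exactly these parameters with no padding or extra storage. rnn_param_count is a hypothetical helper.

```julia
# Count of weight entries implied by the RNN equations: per gate, an input
# matrix (H x nin), a recurrent matrix (H x H) and two bias vectors (bW, bR).
# Layers after the first see H inputs, or 2H when bidirectional; skipInput
# drops the first layer's input matrices.
function rnn_param_count(rnnType::Symbol, X::Int, H::Int, L::Int;
                         bidirectional::Bool=false, skipInput::Bool=false)
    ngates = Dict(:relu => 1, :tanh => 1, :gru => 3, :lstm => 4)[rnnType]
    D = bidirectional ? 2 : 1
    total = 0
    for layer in 1:L
        nin = layer == 1 ? X : D * H
        inmat = (layer == 1 && skipInput) ? 0 : H * nin
        total += D * ngates * (inmat + H * H + 2H)
    end
    return total
end

rnn_param_count(:lstm, 10, 20, 1)   # 4 * (200 + 400 + 40) = 2560
```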

Knet.rnnparams (Function)
rnnparams(r::RNN, w)

Split w into individual parameters and return them as an array.

The order of params returned (subject to change):

  • All weight matrices come before all bias vectors.

  • Matrices and biases are sorted lexically based on (layer,id).

  • See @doc rnnparam for valid layer and id values.

  • Input multiplying matrices are nothing if r.inputMode = 1.

Optimization methods

Knet.update! (Function)
update!(weights, gradients, params)
update!(weights, gradients; lr=0.001, gclip=0)

Update the weights using their gradients and the optimization algorithm parameters specified by params. The 2-arg version defaults to the Sgd algorithm with learning rate lr and gradient clip gclip. gclip==0 indicates no clipping. The weights and possibly gradients and params are modified in-place.

weights can be an individual numeric array or a collection of arrays represented by an iterator or dictionary. In the individual case, gradients should be a similar numeric array of size(weights) and params should be a single object. In the collection case, each individual weight array should have a corresponding params object. This way different weight arrays can have their own optimization state, different learning rates, or even different optimization algorithms running in parallel. In the iterator case, gradients and params should be iterators of the same length as weights with corresponding elements. In the dictionary case, gradients and params should be dictionaries with the same keys as weights.

Individual optimization parameters can be one of the following types. The keyword arguments for each type's constructor and their default values are listed as well.

  • Sgd(;lr=0.001, gclip=0)

  • Momentum(;lr=0.001, gclip=0, gamma=0.9)

  • Nesterov(;lr=0.001, gclip=0, gamma=0.9)

  • Rmsprop(;lr=0.001, gclip=0, rho=0.9, eps=1e-6)

  • Adagrad(;lr=0.1, gclip=0, eps=1e-6)

  • Adadelta(;lr=0.01, gclip=0, rho=0.9, eps=1e-6)

  • Adam(;lr=0.001, gclip=0, beta1=0.9, beta2=0.999, eps=1e-8)

Example:

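using Knet
d, d1, d2 = 10, 5, 3                    # example array sizes
sqloss(w) = sum(w .* w)                 # an illustrative scalar loss
lossgradient  = grad(sqloss)                               # for one array
lossgradient2 = grad(w -> sqloss(w[1]) + sqloss(w[2]))     # for a tuple of arrays
lossgradient3 = lossgradient2                              # for a vector of arrays
lossgradient4 = grad(w -> sqloss(w[:a]) + sqloss(w[:b]))   # for a dictionary of arrays

w = rand(d)                 # an individual weight array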
g = lossgradient(w)         # gradient g has the same shape as w
update!(w, g)               # update w in-place with Sgd()
update!(w, g; lr=0.1)       # update w in-place with Sgd(lr=0.1)
update!(w, g, Sgd(lr=0.1))  # update w in-place with Sgd(lr=0.1)

w = (rand(d1), rand(d2))    # a tuple of weight arrays
g = lossgradient2(w)        # g will also be a tuple
p = (Adam(), Sgd())         # p has params for each w[i]
update!(w, g, p)            # update each w[i] in-place with g[i],p[i]

w = Any[rand(d1), rand(d2)] # any iterator can be used
g = lossgradient3(w)        # g will be similar to w
p = Any[Adam(), Sgd()]      # p should be an iterator of same length
update!(w, g, p)            # update each w[i] in-place with g[i],p[i]

w = Dict(:a => rand(d1), :b => rand(d2)) # dictionaries can be used
g = lossgradient4(w)
p = Dict(:a => Adam(), :b => Sgd())
update!(w, g, p)

Knet.optimizers (Function)
optimizers(model, otype; options...)

Given parameters of a model, initialize and return corresponding optimization parameters for a given optimization type otype and optimization options options. This is useful because each numeric array in model needs its own distinct optimization parameter. optimizers makes the creation of optimization parameters that parallel model parameters easy when all of them use the same type and options.
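
The idea can be sketched as a recursive map that mirrors the model's structure and creates one fresh optimizer object per numeric array. This sketch assumes a model built from numeric arrays nested in tuples, arrays and dictionaries. ToySgd is a hypothetical stand-in for Sgd, Adam and the other types, and optimizers_sketch is not Knet's implementation.

```julia
# Hypothetical stand-in for an optimization parameter type such as Sgd or Adam.
mutable struct ToySgd
    lr::Float64
end
ToySgd(; lr=0.001) = ToySgd(lr)

# Build a structure parallel to the model with a distinct optimizer per array.
optimizers_sketch(w::AbstractArray{<:Number}, otype; o...) = otype(; o...)
optimizers_sketch(m::Tuple, otype; o...) = map(w -> optimizers_sketch(w, otype; o...), m)
optimizers_sketch(m::AbstractArray, otype; o...) = map(w -> optimizers_sketch(w, otype; o...), m)
optimizers_sketch(m::AbstractDict, otype; o...) =
    Dict(k => optimizers_sketch(v, otype; o...) for (k, v) in m)

model = (rand(3, 2), Any[rand(2), rand(4)], Dict(:b => rand(1)))
p = optimizers_sketch(model, ToySgd; lr=0.1)   # same nesting as model, one ToySgd per array
```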

Knet.Adadelta (Type)
Adadelta(;lr=0.01, gclip=0, rho=0.9, eps=1e-6)
update!(w,g,p::Adadelta)

Container for parameters of the Adadelta optimization algorithm used by update!.

Adadelta is an extension of Adagrad that tries to prevent the decrease of the learning rates to zero as training progresses. It scales the learning rate based on the accumulated gradients like Adagrad and holds the acceleration term like Momentum. It updates the weights with the following formulas:

G = (1-rho) * g .^ 2 + rho * G
update = g .* sqrt(delta + eps) ./ sqrt(G + eps)
w = w - lr * update
delta = rho * delta + (1-rho) * update .^ 2

where w is the weight, g is the gradient of the objective function w.r.t. w, lr is the learning rate, and G is an array with the same size and type as w that holds an exponentially decaying average of the squared gradients. eps is a small constant to prevent a zero value in the denominator. rho is the decay rate of these averages, and delta is an array with the same size and type as w that holds an exponentially decaying average of the squared updates.

If vecnorm(g) > gclip > 0, g is scaled so that its norm is equal to gclip. If gclip==0 no scaling takes place.
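
The Adadelta formulas and gradient clipping can be sketched on plain Julia vectors. This assumes Float64 vector weights. AdadeltaState, adadelta_update! and gclip! are hypothetical names, not Knet's types.

```julia
using LinearAlgebra

# Rescale g in place so that norm(g) == gclip whenever norm(g) > gclip > 0.
function gclip!(g, gclip)
    n = norm(g)
    if gclip > 0 && n > gclip
        g .*= gclip / n
    end
    return g
end

# Optimizer state for the Adadelta formulas above (a hypothetical struct).
mutable struct AdadeltaState
    lr::Float64
    gclip::Float64
    rho::Float64
    eps::Float64
    G::Vector{Float64}       # decaying average of squared gradients
    delta::Vector{Float64}   # decaying average of squared updates
end
AdadeltaState(n::Int; lr=0.01, gclip=0, rho=0.9, eps=1e-6) =
    AdadeltaState(lr, gclip, rho, eps, zeros(n), zeros(n))

function adadelta_update!(w, g, p::AdadeltaState)
    gclip!(g, p.gclip)
    p.G .= (1 - p.rho) .* g .^ 2 .+ p.rho .* p.G
    update = g .* sqrt.(p.delta .+ p.eps) ./ sqrt.(p.G .+ p.eps)
    w .-= p.lr .* update
    p.delta .= p.rho .* p.delta .+ (1 - p.rho) .* update .^ 2
    return w
end

w = [1.0, -2.0]
p = AdadeltaState(length(w))
adadelta_update!(w, [0.5, 0.5], p)   # both weights move down slightly
```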

Reference: Zeiler, M. D. (2012). ADADELTA: An Adaptive Learning Rate Method.

Knet.Adagrad (Type)
Adagrad(;lr=0.1, gclip=0, eps=1e-6)
update!(w,g,p::Adagrad)

Container for parameters of the Adagrad optimization algorithm used by update!.

Adagrad is one of the methods that adapt the learning rate to each of the weights. It stores the sum of the squares of the gradients and divides each weight's learning rate by the square root of that sum. Hence, the learning rate is larger for weights whose accumulated gradients are small and smaller for weights whose accumulated gradients are large. It updates the weights with the following formulas:

G = G + g .^ 2
w = w - g .* lr ./ sqrt(G + eps)

where w is the weight, g is the gradient of the objective function w.r.t w, lr is the learning rate, G is an array with the same size and type of w and holds the sum of the squares of the gradients. eps is a small constant to prevent a zero value in the denominator.

If vecnorm(g) > gclip > 0, g is scaled so that its norm is equal to gclip. If gclip==0 no scaling takes place.
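
The Adagrad formulas can be sketched the same way. This assumes Float64 vector weights. AdagradState, adagrad_update! and gclip! are hypothetical names. The usage shows how repeated gradients shrink later steps.

```julia
using LinearAlgebra

# Rescale g in place so that norm(g) == gclip whenever norm(g) > gclip > 0.
function gclip!(g, gclip)
    n = norm(g)
    if gclip > 0 && n > gclip
        g .*= gclip / n
    end
    return g
end

mutable struct AdagradState
    lr::Float64
    gclip::Float64
    eps::Float64
    G::Vector{Float64}   # running sum of squared gradients
end
AdagradState(n::Int; lr=0.1, gclip=0, eps=1e-6) = AdagradState(lr, gclip, eps, zeros(n))

function adagrad_update!(w, g, p::AdagradState)
    gclip!(g, p.gclip)
    p.G .+= g .^ 2
    w .-= g .* p.lr ./ sqrt.(p.G .+ p.eps)
    return w
end

w = [0.0]; p = AdagradState(1)
adagrad_update!(w, [1.0], p)   # step of about 0.1
adagrad_update!(w, [1.0], p)   # smaller step, about 0.1/sqrt(2)
```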

Reference: Duchi, J., Hazan, E., & Singer, Y. (2011). Adaptive Subgradient Methods for Online Learning and Stochastic Optimization. Journal of Machine Learning Research, 12, 2121–2159.

Knet.Adam (Type)
Adam(;lr=0.001, gclip=0, beta1=0.9, beta2=0.999, eps=1e-8)
update!(w,g,p::Adam)

Container for parameters of the Adam optimization algorithm used by update!.

Adam is one of the methods that compute an adaptive learning rate for each weight. It keeps exponentially decaying averages of the gradients (first moment) and of the squared gradients (second moment), and corrects the bias of both moments as a function of the update count. The update formulas are:

m = beta1 * m + (1 - beta1) * g
v = beta2 * v + (1 - beta2) * g .* g
mhat = m ./ (1 - beta1 ^ t)
vhat = v ./ (1 - beta2 ^ t)
w = w - lr * mhat ./ (sqrt(vhat) + eps)

where w is the weight, g is the gradient of the objective function w.r.t. w, and lr is the learning rate. m is an array with the same size and type as w that holds an exponentially decaying average of the gradients (first moment), and v is an array with the same size and type as w that holds an exponentially decaying average of the squared gradients (second moment). eps is a small constant to prevent a zero denominator. beta1 and beta2 are the decay rates of the two averages and are also used to compute the bias-corrected moments mhat and vhat. t is the update count.

If vecnorm(g) > gclip > 0, g is scaled so that its norm is equal to gclip. If gclip==0 no scaling takes place.
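
The Adam formulas can be sketched on plain Julia vectors. This assumes Float64 vector weights. AdamState, adam_update! and gclip! are hypothetical names. The usage illustrates a known property of the bias correction: the first step moves each weight by about lr, regardless of the gradient's scale.

```julia
using LinearAlgebra

# Rescale g in place so that norm(g) == gclip whenever norm(g) > gclip > 0.
function gclip!(g, gclip)
    n = norm(g)
    if gclip > 0 && n > gclip
        g .*= gclip / n
    end
    return g
end

mutable struct AdamState
    lr::Float64
    gclip::Float64
    beta1::Float64
    beta2::Float64
    eps::Float64
    t::Int                 # update count
    m::Vector{Float64}     # first moment estimate
    v::Vector{Float64}     # second moment estimate
end
AdamState(n::Int; lr=0.001, gclip=0, beta1=0.9, beta2=0.999, eps=1e-8) =
    AdamState(lr, gclip, beta1, beta2, eps, 0, zeros(n), zeros(n))

function adam_update!(w, g, p::AdamState)
    gclip!(g, p.gclip)
    p.t += 1
    p.m .= p.beta1 .* p.m .+ (1 - p.beta1) .* g
    p.v .= p.beta2 .* p.v .+ (1 - p.beta2) .* g .* g
    mhat = p.m ./ (1 - p.beta1^p.t)              # bias-corrected first moment
    vhat = p.v ./ (1 - p.beta2^p.t)              # bias-corrected second moment
    w .-= p.lr .* mhat ./ (sqrt.(vhat) .+ p.eps)
    return w
end

w = zeros(2); p = AdamState(2)
adam_update!(w, [1e-3, 50.0], p)   # both weights move by about lr = 0.001
```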

Reference: Kingma, D. P., & Ba, J. L. (2015). Adam: a Method for Stochastic Optimization. International Conference on Learning Representations, 1–13.

Knet.Momentum (Type)
Momentum(;lr=0.001, gclip=0, gamma=0.9)
update!(w,g,p::Momentum)

Container for parameters of the Momentum optimization algorithm used by update!.

The Momentum method tries to accelerate SGD by adding a velocity term to the update. This also decreases the oscillation between successive steps. It updates the weights with the following formulas:

velocity = gamma * velocity + lr * g
w = w - velocity

where w is a weight array, g is the gradient of the objective function w.r.t w, lr is the learning rate, gamma is the momentum parameter, velocity is an array with the same size and type of w and holds the accelerated gradients.

If vecnorm(g) > gclip > 0, g is scaled so that its norm is equal to gclip. If gclip==0 no scaling takes place.
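
The Momentum formulas can be sketched as follows. This assumes Float64 vector weights. MomentumState, momentum_update! and gclip! are hypothetical names. With a constant gradient, the velocity grows step by step.

```julia
using LinearAlgebra

# Rescale g in place so that norm(g) == gclip whenever norm(g) > gclip > 0.
function gclip!(g, gclip)
    n = norm(g)
    if gclip > 0 && n > gclip
        g .*= gclip / n
    end
    return g
end

mutable struct MomentumState
    lr::Float64
    gclip::Float64
    gamma::Float64
    velocity::Vector{Float64}   # accelerated gradients
end
MomentumState(n::Int; lr=0.001, gclip=0, gamma=0.9) = MomentumState(lr, gclip, gamma, zeros(n))

function momentum_update!(w, g, p::MomentumState)
    gclip!(g, p.gclip)
    p.velocity .= p.gamma .* p.velocity .+ p.lr .* g
    w .-= p.velocity
    return w
end

w = [0.0]; p = MomentumState(1; lr=0.1)
momentum_update!(w, [1.0], p)   # velocity 0.1, w = -0.1
momentum_update!(w, [1.0], p)   # velocity 0.19, w = -0.29
```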

Reference: Qian, N. (1999). On the momentum term in gradient descent learning algorithms. Neural Networks : The Official Journal of the International Neural Network Society, 12(1), 145–151.

Knet.Nesterov (Type)
Nesterov(; lr=0.001, gclip=0, gamma=0.9)
update!(w,g,p::Nesterov)

Container for parameters of Nesterov's momentum optimization algorithm used by update!.

It is similar to standard Momentum but with a slightly different update rule:

velocity = gamma * velocity_old - lr * g
w = w_old - gamma * velocity_old + (1+gamma) * velocity

where w is a weight array, g is the gradient of the objective function w.r.t w, lr is the learning rate, gamma is the momentum parameter, velocity is an array with the same size and type of w and holds the accelerated gradients.

If vecnorm(g) > gclip > 0, g is scaled so that its norm is equal to gclip. If gclip == 0 no scaling takes place.
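
The Nesterov update can be sketched in the form used by Bengio, Boulanger-Lewandowski and Pascanu. In that form the previous velocity is scaled by gamma before it is subtracted, and the stored weights play the role of the look-ahead point. This assumes Float64 vector weights. NesterovState, nesterov_update! and gclip! are hypothetical names.

```julia
using LinearAlgebra

# Rescale g in place so that norm(g) == gclip whenever norm(g) > gclip > 0.
function gclip!(g, gclip)
    n = norm(g)
    if gclip > 0 && n > gclip
        g .*= gclip / n
    end
    return g
end

mutable struct NesterovState
    lr::Float64
    gclip::Float64
    gamma::Float64
    velocity::Vector{Float64}
end
NesterovState(n::Int; lr=0.001, gclip=0, gamma=0.9) = NesterovState(lr, gclip, gamma, zeros(n))

function nesterov_update!(w, g, p::NesterovState)
    gclip!(g, p.gclip)
    vold = copy(p.velocity)
    p.velocity .= p.gamma .* vold .- p.lr .* g                 # new velocity
    w .+= (1 + p.gamma) .* p.velocity .- p.gamma .* vold       # look-ahead weight update
    return w
end

w = [0.0]; p = NesterovState(1; lr=0.1)
nesterov_update!(w, [1.0], p)   # velocity -0.1, w = -0.19
nesterov_update!(w, [1.0], p)   # velocity -0.19, w = -0.461
```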

Reference implementation: Yoshua Bengio, Nicolas Boulanger-Lewandowski and Razvan Pascanu.

Knet.Rmsprop (Type)
Rmsprop(;lr=0.001, gclip=0, rho=0.9, eps=1e-6)
update!(w,g,p::Rmsprop)

Container for parameters of the Rmsprop optimization algorithm used by update!.

Rmsprop scales the learning rate by dividing it by the root mean square of recent gradients. It updates the weights with the following formulas:

G = (1-rho) * g .^ 2 + rho * G
w = w - lr * g ./ sqrt(G + eps)

where w is the weight, g is the gradient of the objective function w.r.t. w, lr is the learning rate, and G is an array with the same size and type as w that holds an exponentially decaying average of the squared gradients. eps is a small constant to prevent a zero value in the denominator, and rho is the decay rate of that average.

If vecnorm(g) > gclip > 0, g is scaled so that its norm is equal to gclip. If gclip==0 no scaling takes place.
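
The Rmsprop formulas can be sketched on plain Julia vectors. This assumes Float64 vector weights. RmspropState, rmsprop_update! and gclip! are hypothetical names.

```julia
using LinearAlgebra

# Rescale g in place so that norm(g) == gclip whenever norm(g) > gclip > 0.
function gclip!(g, gclip)
    n = norm(g)
    if gclip > 0 && n > gclip
        g .*= gclip / n
    end
    return g
end

mutable struct RmspropState
    lr::Float64
    gclip::Float64
    rho::Float64
    eps::Float64
    G::Vector{Float64}   # decaying average of squared gradients
end
RmspropState(n::Int; lr=0.001, gclip=0, rho=0.9, eps=1e-6) =
    RmspropState(lr, gclip, rho, eps, zeros(n))

function rmsprop_update!(w, g, p::RmspropState)
    gclip!(g, p.gclip)
    p.G .= (1 - p.rho) .* g .^ 2 .+ p.rho .* p.G
    w .-= p.lr .* g ./ sqrt.(p.G .+ p.eps)
    return w
end

w = [0.0]; p = RmspropState(1)
rmsprop_update!(w, [2.0], p)   # first step: lr * 2 / sqrt(0.4 + eps)
```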

Reference: Tijmen Tieleman and Geoffrey Hinton (2012). "Lecture 6.5-rmsprop: Divide the gradient by a running average of its recent magnitude." COURSERA: Neural Networks for Machine Learning 4.2.

Knet.Sgd (Type)
Sgd(;lr=0.001,gclip=0)
update!(w,g,p::Sgd)
update!(w,g;lr=0.001)

Container for parameters of the Stochastic gradient descent (SGD) optimization algorithm used by update!.

SGD is an optimization technique to minimize an objective function by updating its weights in the opposite direction of their gradient. The learning rate (lr) determines the size of the step. SGD updates the weights with the following formula:

w = w - lr * g

where w is a weight array, g is the gradient of the loss function w.r.t w and lr is the learning rate.

If vecnorm(g) > gclip > 0, g is scaled so that its norm is equal to gclip. If gclip==0 no scaling takes place.
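
A plain SGD step with the clipping rule just described can be sketched as follows. This assumes Float64 arrays. sgd_update! and gclip! are hypothetical names. Like Knet's update!, this sketch may modify the gradient in place when it clips.

```julia
using LinearAlgebra

# Rescale g in place so that norm(g) == gclip whenever norm(g) > gclip > 0.
function gclip!(g, gclip)
    n = norm(g)
    if gclip > 0 && n > gclip
        g .*= gclip / n
    end
    return g
end

# Plain SGD step with optional gradient clipping.
function sgd_update!(w, g; lr=0.001, gclip=0)
    gclip!(g, gclip)
    w .-= lr .* g
    return w
end

w = [1.0, 1.0]; g = [3.0, 4.0]
sgd_update!(w, g; lr=0.1, gclip=1.0)   # g is first rescaled to [0.6, 0.8]
```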

SGD is used by default if no algorithm is specified in the two-argument version of update!.

Hyperparameter optimization

Knet.goldensection (Function)
goldensection(f,n;kwargs) => (fmin,xmin)

Find the minimum of f using concurrent golden section search in n dimensions. See Knet.goldensection_demo() for an example.

f is a function from a Vector{Float64} of length n to a Number. It can return NaN for out of range inputs. Goldensection will always start with a zero vector as the initial input to f, and the initial step size will be 1 in each dimension. The user should define f to scale and shift this input range into a vector meaningful for their application. For positive inputs like learning rate or hidden size, you can use a transformation such as x0*exp(x) where x is a value goldensection passes to f and x0 is your initial guess for this value. This will effectively start the search at x0, then move with multiplicative steps.

I designed this algorithm combining ideas from Golden Section Search and Hill Climbing Search. It essentially runs golden section search concurrently in each dimension, picking the next step based on estimated gain.
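
The classic one-dimensional golden section search that this algorithm builds on can be sketched as follows. The sketch also applies the x0*exp(x) reparametrization suggested above. It is the textbook method, not Knet's concurrent goldensection, because it searches a fixed bracket [a, b] instead of starting from zero with unit steps. golden1d and lrloss are hypothetical names.

```julia
# Classic golden section search for the minimum of a unimodal f on [a, b].
function golden1d(f, a::Float64, b::Float64; tol=1e-6)
    invphi = (sqrt(5) - 1) / 2                 # 1/φ ≈ 0.618
    c = b - invphi * (b - a); d = a + invphi * (b - a)
    fc, fd = f(c), f(d)
    while b - a > tol
        if fc < fd                             # the minimum lies in [a, d]
            b, d, fd = d, c, fc
            c = b - invphi * (b - a); fc = f(c)
        else                                   # the minimum lies in [c, b]
            a, c, fc = c, d, fd
            d = a + invphi * (b - a); fd = f(d)
        end
    end
    x = (a + b) / 2
    return f(x), x
end

fmin, xmin = golden1d(x -> (x - 1.3)^2, -5.0, 5.0)   # xmin ≈ 1.3

# Reparametrized search for a positive value such as a learning rate.
lrloss(lr) = (log(lr) - log(0.05))^2   # hypothetical loss, best at lr = 0.05
x0 = 0.01                              # initial guess
_, xbest = golden1d(x -> lrloss(x0 * exp(x)), -5.0, 5.0)
x0 * exp(xbest)                        # ≈ 0.05
```

Each iteration reuses one of the two previous function evaluations, which is what makes the golden ratio split attractive when f is expensive.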

Keyword arguments

  • dxmin=0.1: smallest step size.

  • accel=φ: acceleration rate. Golden ratio φ=1.618... is best.

  • verbose=false: use true to print individual steps.

  • history=[]: cache of [(x,f(x)),...] function evaluations.

Knet.hyperband (Function)
hyperband(getconfig, getloss, maxresource=27, reduction=3)

Hyperparameter optimization using the hyperband algorithm from (Li et al. 2016). You can try a simple MNIST example using Knet.hyperband_demo().

Arguments

  • getconfig() returns random configurations with a user defined type and distribution.

  • getloss(c,n) returns loss for configuration c and number of resources (e.g. epochs) n.

  • maxresource is the maximum number of resources any one configuration should be given.

  • reduction is an algorithm parameter (see paper), 3 is a good value.
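
A compact sketch of the hyperband loop from the paper follows. It assumes that getconfig and getloss behave as described above. hyperband_sketch is hypothetical and passes resources as floating-point numbers. It may differ from Knet's implementation in details such as rounding and how the best result is reported.

```julia
# Sketch of hyperband (Li et al. 2016): several brackets of successive halving,
# trading the number of sampled configurations against resources per configuration.
function hyperband_sketch(getconfig, getloss, maxresource=27, reduction=3)
    smax = 0                                  # largest s with reduction^s <= maxresource
    while reduction^(smax + 1) <= maxresource
        smax += 1
    end
    budget = (smax + 1) * maxresource
    best = (Inf, nothing)                     # (loss, configuration)
    for s in smax:-1:0
        n = ceil(Int, budget / maxresource * reduction^s / (s + 1))  # configurations to try
        r = maxresource / reduction^s                                 # their initial resources
        configs = [getconfig() for _ in 1:n]
        for i in 0:s
            isempty(configs) && break
            losses = [getloss(c, r * reduction^i) for c in configs]
            order = sortperm(losses)
            if losses[order[1]] < best[1]
                best = (losses[order[1]], configs[order[1]])
            end
            configs = configs[order[1:div(length(configs), reduction)]]  # keep the best 1/reduction
        end
    end
    return best
end

getconfig() = rand()                     # e.g. one hyperparameter drawn from [0, 1)
getloss(c, n) = abs(c - 0.3) + 1 / n     # toy loss: better near 0.3 and with more resources
loss, c = hyperband_sketch(getconfig, getloss)
```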

Initialization

Knet.bilinear (Function)

Bilinear interpolation filter weights; used for initializing deconvolution layers.

Adapted from https://github.com/shelhamer/fcn.berkeleyvision.org/blob/master/surgery.py#L33

Arguments:

  • T: data type

  • fw: width upscale factor

  • fh: height upscale factor

  • IN: number of input filters

  • ON: number of output filters

Example usage:

w = bilinear(Float32,2,2,128,128)
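
The 2-D kernel itself can be sketched by following the upsample_filt function in the linked surgery.py. The sketch takes the kernel size as input. A common convention, assumed here, is a size of 2f - f % 2 for an upscale factor f. How bilinear places kernels into its returned weight array for the IN and ON filters is not shown. bilinear_kernel is a hypothetical name.

```julia
# 2-D bilinear upsampling kernel of a given size, after upsample_filt in surgery.py.
function bilinear_kernel(size::Int)
    factor = (size + 1) ÷ 2
    center = isodd(size) ? factor - 1 : factor - 0.5
    profile = [1 - abs(i - center) / factor for i in 0:size-1]   # 0-based as in numpy
    return profile * profile'     # separable: outer product of the 1-D profiles
end

bilinear_kernel(4)   # outer product of [0.25, 0.75, 0.75, 0.25] with itself
```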

Knet.gaussian (Function)
gaussian(a...; mean=0.0, std=0.01)

Return a Gaussian array with a given mean and standard deviation. The a arguments are passed to randn.

Knet.xavier (Function)
xavier(a...)

Xavier initialization. The a arguments are passed to rand. See (Glorot and Bengio 2010) for a description. Caffe implements this slightly differently. Lasagne calls it GlorotUniform.
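
A sketch of the Glorot and Bengio uniform scheme follows. It draws entries from U(-s, s) with s = sqrt(6 / (fanin + fanout)). The fan convention is an assumption and may not match Knet's xavier exactly. Matrices are treated as fanout x fanin, and for other arrays the last dimension is fanout. Only fanin + fanout enters s, so the convention matters little here. xavier_sketch is a hypothetical name.

```julia
# Glorot/Xavier uniform initialization: entries from U(-s, s) with
# s = sqrt(6 / (fanin + fanout)). The fan convention below is an assumption.
function xavier_sketch(dims::Int...)
    length(dims) >= 2 || throw(ArgumentError("need at least two dimensions"))
    if length(dims) == 2
        fanout, fanin = dims                  # matrix used as w * x
    else
        fanout = dims[end]                    # e.g. output channels of a filter
        fanin = prod(dims) ÷ fanout
    end
    s = sqrt(6 / (fanin + fanout))
    return 2s .* rand(dims...) .- s           # rand is uniform on [0, 1)
end

w = xavier_sketch(100, 50)          # 100x50 matrix with entries in [-s, s)
k = xavier_sketch(3, 3, 16, 32)     # a conv4-style filter array
```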

AutoGrad (advanced)

AutoGrad.getval (Function)
getval(x)

Unbox x if it is a boxed value (Rec), otherwise return x.

AutoGrad.@primitive (Macro)
@primitive fx g1 g2...

Define a new primitive operation for AutoGrad and (optionally) specify its gradients. Non-differentiable functions such as sign, and non-numeric functions such as size should be defined using the @zerograd macro instead.

Examples

@primitive sin(x::Number)
@primitive hypot(x1::Array,x2::Array),dy,y

@primitive sin(x::Number),dy  (dy*cos(x))
@primitive hypot(x1::Array,x2::Array),dy,y  (dy.*x1./y)  (dy.*x2./y)

The first example shows that fx is a typed method declaration. Julia supports multiple dispatch, i.e. a single function can have multiple methods with different arg types. AutoGrad takes advantage of this and supports multiple dispatch for primitives and gradients.

The second example specifies variable names for the output gradient dy and the output y after the method declaration which can be used in gradient expressions. Untyped, ellipsis and keyword arguments are ok as in f(a::Int,b,c...;d=1). Parametric methods such as f{T<:Number}(x::T) cannot be used.

The method declaration can optionally be followed by gradient expressions. The third and fourth examples show how gradients can be specified. Note that the parameters, the return variable and the output gradient of the original function can be used in the gradient expressions.

Under the hood

The @primitive macro turns the first example into:

let sin_r = recorder(sin)
    global sin
    sin{T<:Number}(x::Rec{T}) = sin_r(x)
end

This will cause calls to sin with a boxed argument (Rec{T<:Number}) to be recorded. The recorded operations are used by grad to construct a dynamic computational graph. With multiple arguments things are a bit more complicated. Here is what happens with the second example:

let hypot_r = recorder(hypot)
    global hypot
    hypot{T<:Array,S<:Array}(x1::Rec{T},x2::Rec{S})=hypot_r(x1,x2)
    hypot{T<:Array,S<:Array}(x1::Rec{T},x2::S)=hypot_r(x1,x2)
    hypot{T<:Array,S<:Array}(x1::T,x2::Rec{S})=hypot_r(x1,x2)
end

We want the recorder version to be called if any one of the arguments is a boxed Rec. There is no easy way to specify this in Julia, so the macro generates all 2^N-1 boxed/unboxed argument combinations.

In AutoGrad, gradients are defined using gradient methods that have the following signature:

f(Grad{i},dy,y,x...) => dx[i]

For the third example here is the generated gradient method:

sin{T<:Number}(::Type{Grad{1}}, dy, y, x::Rec{T})=(dy*cos(x))

For the last example a different gradient method is generated for each argument:

hypot{T<:Array,S<:Array}(::Type{Grad{1}},dy,y,x1::Rec{T},x2::Rec{S})=(dy.*x1./y)
hypot{T<:Array,S<:Array}(::Type{Grad{2}},dy,y,x1::Rec{T},x2::Rec{S})=(dy.*x2./y)

In fact @primitive generates four more definitions for the other boxed/unboxed argument combinations.
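
The record-then-replay idea behind @primitive can be shown with a toy tape in current Julia. This is a heavily simplified illustration, not AutoGrad's code. Rec, Tape, Node, record, toyval and toygrad are hypothetical names. The toy handles only scalars, and it handles only calls where every argument is boxed.

```julia
# Toy reverse-mode recorder in the spirit of @primitive: boxed values (Rec)
# record each primitive call on a tape together with its gradient rules.
struct Node
    value::Float64
    parents::Vector{Tuple{Int,Function}}   # (tape index of an input, dy -> dx rule)
end
struct Tape
    nodes::Vector{Node}
end
struct Rec                                  # a boxed value: a position on a tape
    tape::Tape
    idx::Int
end
toyval(r::Rec) = r.tape.nodes[r.idx].value

function record(t::Tape, value, parents)
    push!(t.nodes, Node(value, parents))
    return Rec(t, length(t.nodes))
end

# "Primitives": forward value plus one gradient expression per argument,
# like `@primitive sin(x::Number),dy (dy*cos(x))`.
Base.sin(x::Rec) = record(x.tape, sin(toyval(x)),
    Tuple{Int,Function}[(x.idx, dy -> dy * cos(toyval(x)))])
Base.:*(a::Rec, b::Rec) = record(a.tape, toyval(a) * toyval(b),
    Tuple{Int,Function}[(a.idx, dy -> dy * toyval(b)), (b.idx, dy -> dy * toyval(a))])

# toygrad(f)(x): run f on a boxed x, then walk the tape backwards.
function toygrad(f)
    return function (x::Real)
        t = Tape(Node[])
        r = record(t, float(x), Tuple{Int,Function}[])
        y = f(r)
        dy = zeros(length(t.nodes)); dy[y.idx] = 1.0
        for i in length(t.nodes):-1:1, (p, rule) in t.nodes[i].parents
            dy[p] += rule(dy[i])
        end
        return dy[r.idx]
    end
end

toygrad(x -> sin(x) * x)(1.0)   # cos(1) + sin(1) ≈ 1.3818
```

A call such as sin(x) * 2.0 fails in this toy, because no method of * accepts a Rec together with a plain number. That missing case is the reason @primitive generates every boxed/unboxed argument combination.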

AutoGrad.@zerograd (Macro)
@zerograd f(args...; kwargs...)

Define f as an AutoGrad primitive operation with zero gradient.

Example:

@zerograd floor(x::Float32)

@zerograd allows f to handle boxed Rec inputs by unboxing them like a @primitive, but unlike @primitive it does not record its actions or return a boxed Rec result. Some functions, like sign(), have zero gradient. Others, like length(), have discrete or constant outputs. These need to handle Rec inputs, but do not need to record anything and can return regular values. Their output can be treated like a constant in the program. Use the @zerograd macro for those. Note that kwargs are NOT unboxed.
