Reference
AutoGrad
AutoGrad.grad
— Function
grad(fun, argnum=1)
Take a function fun(X...)->Y and return another function gfun(X...)->dXi which computes its gradient with respect to positional argument number argnum. The function fun should be scalar-valued. The returned function gfun takes the same arguments as fun, but returns the gradient instead. The gradient has the same type and size as the target argument, which can be a Number, Array, Tuple, or Dict.
AutoGrad.gradloss
— Function
gradloss(fun, argnum=1)
Another version of grad where the generated function returns a (gradient,value) pair.
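To make the contract of grad and gradloss concrete (a function in, a same-signature gradient function out), here is a minimal finite-difference stand-in. It is only an illustration: AutoGrad differentiates exactly by recording operations, whereas numgrad and numgradloss below are hypothetical helpers that estimate the gradient by central differences and assume the target argument is a floating-point number or array.

```julia
# Hypothetical finite-difference stand-ins for grad/gradloss (NOT how AutoGrad
# works). numgrad(fun, argnum) returns a function taking the same positional
# arguments as fun and estimating d fun / d args[argnum] by central differences.
# Assumes the target argument is a floating-point Number or Array.
function numgrad(fun, argnum::Int=1; delta=1e-6)
    function gfun(args...)
        x = args[argnum]
        # Call fun with args[argnum] replaced by xi, other arguments unchanged.
        f(xi) = fun(ntuple(j -> j == argnum ? xi : args[j], length(args))...)
        if x isa Number
            return (f(x + delta) - f(x - delta)) / (2delta)
        end
        g = similar(x)                       # same type and size as the target
        for i in eachindex(x)
            xp = copy(x); xp[i] += delta
            xm = copy(x); xm[i] -= delta
            g[i] = (f(xp) - f(xm)) / (2delta)
        end
        return g
    end
    return gfun
end

# gradloss-style variant: return the (gradient, value) pair.
numgradloss(fun, argnum::Int=1) = (args...) -> (numgrad(fun, argnum)(args...), fun(args...))
```

For example, with loss(w, x) = sum((w .* x) .^ 2), numgrad(loss)(w, x) approximates 2 .* w .* x.^2 and numgrad(loss, 2)(w, x) approximates the gradient with respect to x.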
AutoGrad.gradcheck
— Function
gradcheck(f, x...; kwargs...)
Numerically check the gradient of f(x...) and return a boolean result.
Each argument can be a Number, Array, Tuple or Dict, which in turn can contain other Arrays etc. By default, only 10 random entries in each large numeric array are checked. If the output of f is not a number, we check the gradient of sum(f(x...)).
Keywords
args=:
: the argument indices to check gradients with respect to. Could be an array or range of indices or a single index. By default all arguments that have a length method are checked.
kw=()
: keyword arguments to be passed to f.
nsample=10
: number of random entries from each numeric array in gradient dw=(grad(f))(w,x...;o...) compared to their numerical estimates.
atol=rtol=0.01
: tolerance parameters. See isapprox for their meaning.
delta=0.0001
: step size for numerical gradient calculation.
verbose=1
: 0 prints nothing, 1 shows failing tests, 2 shows all tests.
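The core idea of gradcheck, comparing an analytic gradient with central-difference estimates on a few random entries, can be sketched as follows. checkgrad is a hypothetical helper for a single floating-point array argument, not Knet's implementation; it mirrors the nsample, delta, atol, rtol and verbose keywords listed above but takes the analytic gradient dw as an explicit argument instead of calling grad.

```julia
using Random

# Hypothetical sketch of the gradcheck idea (not Knet's code): compare the
# analytic gradient dw of a scalar function f at w with central differences
# on up to nsample randomly chosen entries of w.
function checkgrad(f, w::AbstractArray{<:AbstractFloat}, dw::AbstractArray;
                   nsample=10, delta=1e-4, atol=0.01, rtol=0.01, verbose=1)
    idx = length(w) <= nsample ? (1:length(w)) : randperm(length(w))[1:nsample]
    ok = true
    for i in idx
        wi = w[i]
        w[i] = wi + delta; fp = f(w)
        w[i] = wi - delta; fm = f(w)
        w[i] = wi                               # restore the perturbed entry
        nd = (fp - fm) / (2delta)               # numerical estimate
        pass = isapprox(dw[i], nd; atol=atol, rtol=rtol)
        ok &= pass
        if verbose == 2 || (verbose == 1 && !pass)
            println("index $i: analytic=$(dw[i]) numeric=$nd pass=$pass")
        end
    end
    return ok
end
```

For f(w) = sum(w .^ 2) the correct gradient 2 .* w passes, while a wrong one such as 3 .* w .+ 1 is reported as failing.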
KnetArray
Knet.KnetArray
— Type
KnetArray{T}(undef,dims)
KnetArray(a::AbstractArray)
Array(k::KnetArray)
Container for GPU arrays that supports most of the AbstractArray interface. The constructor allocates a KnetArray in the currently active device, as specified by gpu(). KnetArrays and Arrays can be converted to each other as shown above, which involves copying to and from the GPU memory. Only Float32/64 KnetArrays are fully supported.
Important differences from the alternative CudaArray are: (1) a custom memory manager that minimizes the number of calls to the slow cudaMalloc by reusing already allocated but garbage collected GPU pointers; (2) a custom getindex that handles ranges such as a[5:10] as views with shared memory instead of copies; (3) custom CUDA kernels that implement elementwise, broadcasting, and reduction operations.
Supported functions:
Indexing: getindex, setindex! with the following index types:
- 1-D: Real, Colon, OrdinalRange, AbstractArray{Real}, AbstractArray{Bool}, CartesianIndex, AbstractArray{CartesianIndex}, EmptyArray, KnetArray{Int32} (low level), KnetArray{0/1} (using float for BitArray) (1-D includes linear indexing of multidimensional arrays)
- 2-D: (Colon,Union{Real,Colon,OrdinalRange,AbstractVector{Real},AbstractVector{Bool},KnetVector{Int32}}), (Union{Real,AbstractUnitRange,Colon}...) (in any order)
- N-D: (Real...)
Array operations: ==, !=, cat, convert, copy, copyto!, deepcopy, display, eachindex, eltype, endof, fill!, first, hcat, isapprox, isempty, length, ndims, one, ones, pointer, rand!, randn!, reshape, similar, size, stride, strides, summary, vcat, vec, zero. (cat(x,y,dims=i) supported for i=1,2.)
Math operators: (-), abs, abs2, acos, acosh, asin, asinh, atan, atanh, cbrt, ceil, cos, cosh, cospi, erf, erfc, erfcinv, erfcx, erfinv, exp, exp10, exp2, expm1, floor, log, log10, log1p, log2, round, sign, sin, sinh, sinpi, sqrt, tan, tanh, trunc
Broadcasting operators: (.*), (.+), (.-), (./), (.<), (.<=), (.!=), (.==), (.>), (.>=), (.^), max, min. (Boolean operators generate outputs with same type as inputs; no support for KnetArray{Bool}.)
Reduction operators: countnz, maximum, mean, minimum, prod, sum, sumabs, sumabs2, norm.
Linear algebra: (*), axpy!, permutedims (up to 5D), transpose
Knet extras: relu, sigm, invx, logp, logsumexp, conv4, pool, deconv4, unpool, mat, update! (Only 4D/5D, Float32/64 KnetArrays support conv4, pool, deconv4, unpool)
Memory management
Knet models do not overwrite arrays which need to be preserved for gradient calculation. This leads to a lot of allocation, and regular GPU memory allocation is prohibitively slow. Fortunately most models use identically sized arrays over and over again, so we can minimize the number of actual allocations by reusing preallocated but garbage collected pointers.
When Julia gc reclaims a KnetArray, a special finalizer keeps its pointer in a table instead of releasing the memory. If an array with the same size in bytes is later requested, the same pointer is reused. The exact algorithm for allocation is:
- Try to find a previously allocated and garbage collected pointer in the current device. (0.5 μs)
- If not available, try to allocate a new array using cudaMalloc. (10 μs)
- If not successful, try running gc() and see if we get a pointer of the right size. (75 ms, but this should be amortized over all reusable pointers that become available due to the gc)
- Finally, if all else fails, clean up all saved pointers in the current device using cudaFree and try allocation one last time. (25-70 ms; however, this causes the elimination of all reusable pointers)
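The pointer-reuse policy above can be simulated on the CPU. The sketch below is a toy model, not Knet's memory manager: plain Vector{UInt8} buffers stand in for GPU pointers, and MemPool, alloc!, free! and release_all! are hypothetical names. It models reuse by exact byte size, the fallback to a fresh allocation, and the final clean-up of all saved buffers; the intermediate gc() retry is omitted.

```julia
# Toy CPU simulation of the allocation policy described above. Buffers stand
# in for GPU pointers; MemPool/alloc!/free!/release_all! are hypothetical.
mutable struct MemPool
    free::Dict{Int,Vector{Vector{UInt8}}}   # byte size => reusable buffers
    nmalloc::Int                            # count of fresh ("cudaMalloc") allocations
end
MemPool() = MemPool(Dict{Int,Vector{Vector{UInt8}}}(), 0)

# Reuse a saved buffer of exactly the same byte size if one exists,
# otherwise fall back to a fresh allocation.
function alloc!(pool::MemPool, nbytes::Int)
    bufs = get(pool.free, nbytes, nothing)
    if bufs !== nothing && !isempty(bufs)
        return pop!(bufs)
    end
    pool.nmalloc += 1
    return Vector{UInt8}(undef, nbytes)
end

# What the finalizer does instead of releasing memory: file the buffer for reuse.
free!(pool::MemPool, buf::Vector{UInt8}) =
    push!(get!(() -> Vector{Vector{UInt8}}(), pool.free, length(buf)), buf)

# Last resort: drop every saved buffer (the cudaFree clean-up).
release_all!(pool::MemPool) = empty!(pool.free)
```

Keying the table on the exact byte size keeps lookups trivial, which works well precisely because most models request identically sized arrays repeatedly.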
Utilities
Knet.accuracy
— Function
accuracy(scores, answers; dims=1, average=true)
Given an unnormalized scores matrix and an Integer array of correct answers, return the ratio of instances where the correct answer has the maximum score. dims=1 means instances are in columns, dims=2 means instances are in rows. Use average=false to return the number of correct answers instead of the ratio.
accuracy(model, data, predict; average=true)
Compute accuracy(predict(model,x), y) for (x,y) in data and return the ratio (if average=true) or the count (if average=false) of correct answers.
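The scores/answers form of accuracy reduces to an argmax along the class dimension. accuracy_sketch below is a hypothetical plain-Julia version for CPU matrices, written to show how dims selects the instance layout; it is not Knet's implementation.

```julia
# Hypothetical plain-Julia sketch of accuracy(scores, answers; dims, average).
# answers[k] is the index of the correct class for instance k.
function accuracy_sketch(scores::AbstractMatrix, answers::AbstractVector{<:Integer};
                         dims=1, average=true)
    # dims=1: instances are columns, so the class is the row of the column max;
    # dims=2: instances are rows, so the class is the column of the row max.
    pred = [ci[dims] for ci in vec(argmax(scores; dims=dims))]
    ncorrect = count(pred .== answers)
    return average ? ncorrect / length(answers) : ncorrect
end
```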
Knet.dir
— Function
Knet.dir(path...)
Construct a path relative to Knet root.
Example
julia> Knet.dir("examples","mnist.jl")
"/home/dyuret/.julia/v0.5/Knet/examples/mnist.jl"
Knet.dropout
— Function
dropout(x, p)
Given an array x and probability 0<=p<=1, just return x if p==0, or return an array y in which each element is 0 with probability p or x[i]/(1-p) with probability 1-p. Use seed::Number to set the random number seed for reproducible results. See (Srivastava et al. 2014) for a reference.
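The dropout rule can be written in a few lines for CPU arrays. dropout_sketch is a hypothetical helper, not Knet's GPU kernel, and it takes an explicit random number generator in place of the seed option: passing a freshly seeded MersenneTwister gives a reproducible mask.

```julia
using Random

# Hypothetical CPU sketch of the dropout rule above (not Knet's kernel).
# Each entry is zeroed with probability p; survivors are scaled by 1/(1-p),
# which keeps the expected value of every entry unchanged.
function dropout_sketch(x::AbstractArray{T}, p::Real;
                        rng::AbstractRNG=Random.default_rng()) where {T<:AbstractFloat}
    0 <= p <= 1 || throw(ArgumentError("p must be in [0,1]"))
    p == 0 && return x                      # identity when nothing is dropped
    p == 1 && return zero(x)                # everything is dropped
    keep = rand(rng, size(x)...) .> p       # true with probability 1-p
    return x .* keep .* T(1 / (1 - p))
end
```

For instance, dropout_sketch(x, 0.5; rng=MersenneTwister(1)) returns the same mask on every call.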
Knet.gpu
— Function
gpu() returns the id of the active GPU device or -1 if none are active.
gpu(true) resets all GPU devices and activates the one with the most available memory.
gpu(false) resets and deactivates all GPU devices.
gpu(d::Int) activates the GPU device d if 0 <= d < gpuCount(), otherwise deactivates devices.
gpu(true/false) resets all devices. If there are any allocated KnetArrays, their pointers will be left dangling. Thus gpu(true/false) should only be used during startup. If you want to suspend GPU use temporarily, use gpu(-1).
gpu(d::Int) does not reset the devices. You can select a previous device and find allocated memory preserved. However, trying to operate on arrays of an inactive device will result in an error.
Knet.invx
— Function
invx(x) = (1./x)
Knet.gc
— Function
Knet.gc(dev=gpu())
cudaFree all pointers allocated on device dev that were previously allocated and garbage collected. Normally Knet holds on to all garbage collected pointers for reuse. Try this if you run out of GPU memory.
Knet.logp
— Function
logp(x;[dims])
Treat entries in x as unnormalized log probabilities and return normalized log probabilities.
dims is an optional argument: if it is not specified, the normalization is over the whole x; otherwise the normalization is performed over the given dimensions. In particular, if x is a matrix, dims=1 normalizes columns of x and dims=2 normalizes rows of x.
Knet.logsumexp
— Function
logsumexp(x;dims=:)
Compute log(sum(exp(x);dims)) in a numerically stable manner.
dims is an optional argument: if it is not specified, the summation is over the whole x; otherwise the summation is performed over the given dimensions. In particular, if x is a matrix, dims=1 sums columns of x and dims=2 sums rows of x.
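The usual way to make logsumexp numerically stable is to subtract the maximum before exponentiating, and logp then follows as x minus logsumexp over the same dimensions. The sketch below illustrates this with hypothetical names on plain Julia arrays; it is the standard max-shift technique, not necessarily how Knet's kernels are written.

```julia
# Max-shift logsumexp: log(sum(exp(x))) = m + log(sum(exp(x .- m))), where m is
# the maximum, so no exp() argument exceeds 0 and nothing overflows.
function logsumexp_sketch(x; dims=:)
    m = maximum(x; dims=dims)
    return m .+ log.(sum(exp.(x .- m); dims=dims))
end

# Normalized log probabilities over the same dimensions.
logp_sketch(x; dims=:) = x .- logsumexp_sketch(x; dims=dims)
```

A naive log(sum(exp.([1000.0, 1000.0]))) returns Inf, while logsumexp_sketch gives 1000 + log(2).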
Knet.minibatch
— Function
minibatch(x, y, batchsize; shuffle, partial, xtype, ytype)
Return an iterable of minibatches [(xi,yi)...] given data tensors x, y and batchsize. The last dimension of x and y should match and give the number of instances. Keyword arguments:
shuffle=false
: Shuffle the instances before minibatching.
partial=false
: If true, include the last partial minibatch < batchsize.
xtype=typeof(x)
: Convert xi in minibatches to this type.
ytype=typeof(y)
: Convert yi in minibatches to this type.
minibatch(x, batchsize; shuffle, partial, xtype, ytype)
Return an iterable of minibatches [x1,x2,...] given data tensor x and batchsize. The last dimension of x gives the number of instances. Keyword arguments:
shuffle=false
: Shuffle the instances before minibatching.
partial=false
: If true, include the last partial minibatch < batchsize.
xtype=typeof(x)
: Convert xi in minibatches to this type.
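The slicing behind minibatch(x, y, batchsize) can be sketched as follows. minibatch_sketch is a hypothetical helper that builds an eager vector of (xi, yi) pairs rather than Knet's lazy iterable, and it does not implement xtype/ytype conversion; it shows how shuffle and partial affect the batches taken along the last dimension.

```julia
using Random

# Hypothetical sketch of minibatching along the last dimension (eager list of
# (xi, yi) pairs; no xtype/ytype conversion, unlike Knet's iterator).
function minibatch_sketch(x::AbstractArray, y::AbstractArray, batchsize::Int;
                          shuffle=false, partial=false)
    n = size(x, ndims(x))
    n == size(y, ndims(y)) || throw(DimensionMismatch("instance counts differ"))
    order = shuffle ? randperm(n) : collect(1:n)
    nbatch = partial ? cld(n, batchsize) : div(n, batchsize)
    batches = Tuple[]
    for b in 1:nbatch
        idx = order[(b-1)*batchsize+1 : min(b*batchsize, n)]
        push!(batches, (collect(selectdim(x, ndims(x), idx)),
                        collect(selectdim(y, ndims(y), idx))))
    end
    return batches
end
```

With 10 instances and batchsize=4, the default drops the last 2 instances, while partial=true returns a third batch holding them.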
Knet.nll
— Function
nll(scores, answers; dims=1, average=true)
Given an unnormalized scores matrix and an Integer array of correct answers, return the per-instance negative log likelihood. dims=1 means instances are in columns, dims=2 means instances are in rows. Use average=false to return the sum instead of the per-instance average.
nll(f, data; average=true)
Compute nll(f(x), y) for (x,y) in data and return the per-instance average (if average=true) or total (if average=false) negative log likelihood.
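The scores/answers form of nll amounts to a log-softmax followed by picking the entry of the correct answer for each instance. nll_sketch below is a hypothetical CPU version written for illustration, not Knet's implementation.

```julia
# Hypothetical sketch of nll(scores, answers; dims, average) on plain matrices.
function nll_sketch(scores::AbstractMatrix, answers::AbstractVector{<:Integer};
                    dims=1, average=true)
    m = maximum(scores; dims=dims)
    logp = scores .- (m .+ log.(sum(exp.(scores .- m); dims=dims)))  # log-softmax
    # Log probability of the correct answer for every instance.
    picked = dims == 1 ? [logp[answers[j], j] for j in eachindex(answers)] :
                         [logp[i, answers[i]] for i in eachindex(answers)]
    total = -sum(picked)
    return average ? total / length(answers) : total
end
```

With uniform scores over 4 classes, every instance contributes log(4), so the per-instance average is log(4).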
Knet.relu
— Function
relu(x) = max(0,x)
Knet.seed!
— Function
Knet.seed!(n::Integer)
Run seed!(n) on both CPU and GPU.
Knet.sigm
— Function
sigm(x) = (1./(1+exp(-x)))
Convolution and Pooling
Knet.conv4
— Function
conv4(w, x; kwargs...)
Execute convolutions or cross-correlations using filters specified with w over tensor x.
Currently KnetArray{Float32/64,4/5} and Array{Float32/64,4} are supported as w and x. If w has dimensions (W1,W2,...,I,O) and x has dimensions (X1,X2,...,I,N), the result y will have dimensions (Y1,Y2,...,O,N) where
Yi=1+floor((Xi+2*padding[i]-Wi)/stride[i])
Here I is the number of input channels, O is the number of output channels, N is the number of instances, and Wi,Xi,Yi are spatial dimensions. padding and stride are keyword arguments that can be specified as a single number (in which case they apply to all dimensions), or an array/tuple with entries for each spatial dimension.
Keywords
padding=0
: the number of extra zeros implicitly concatenated at the start and at the end of each dimension.
stride=1
: the number of elements to slide to reach the next filtering window.
upscale=1
: upscale factor for each dimension.
mode=0
: 0 for convolution and 1 for cross-correlation.
alpha=1
: can be used to scale the result.
handle
: handle to a previously created cuDNN context. Defaults to a Knet allocated handle.
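The output-size formula for conv4 is easy to get wrong by one, so here it is as a small function. convsize is a hypothetical helper that only does the dimension arithmetic (no convolution is performed); a scalar padding or stride applies to every spatial dimension, as described above.

```julia
# Dimension arithmetic for conv4: Yi = 1 + floor((Xi + 2*padding[i] - Wi) / stride[i]).
# Hypothetical helper; xdims = (X1,...,I,N), wdims = (W1,...,I,O).
function convsize(xdims::Tuple, wdims::Tuple; padding=0, stride=1)
    nsp = length(xdims) - 2                    # number of spatial dimensions
    xdims[end-1] == wdims[end-1] || throw(DimensionMismatch("input channels differ"))
    pad = padding isa Number ? fill(padding, nsp) : collect(padding)
    str = stride isa Number ? fill(stride, nsp) : collect(stride)
    sp = ntuple(i -> 1 + fld(xdims[i] + 2pad[i] - wdims[i], str[i]), nsp)
    return (sp..., wdims[end], xdims[end])     # (Y1,...,O,N)
end
```

For example, a 5x5 filter on 28x28 inputs gives 24x24 outputs, and padding=2 keeps them at 28x28.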
Knet.deconv4
— Function
y = deconv4(w, x; kwargs...)
Simulate 4-D deconvolution by using the transposed convolution operation. Its forward pass is equivalent to the backward pass of a convolution (gradients with respect to the input tensor). Likewise, its backward pass (gradients with respect to the input tensor) is equivalent to the forward pass of a convolution. Since it swaps the forward and backward passes of the convolution operation, the padding and stride options apply to the output tensor. See this report for further explanation.
Currently KnetArray{Float32/64,4} and Array{Float32/64,4} are supported as w and x. If w has dimensions (W1,W2,...,O,I) and x has dimensions (X1,X2,...,I,N), the result y will have dimensions (Y1,Y2,...,O,N) where
Yi = Wi+stride[i]*(Xi-1)-2*padding[i]
Here I is the number of input channels, O is the number of output channels, N is the number of instances, and Wi,Xi,Yi are spatial dimensions. padding and stride are keyword arguments that can be specified as a single number (in which case they apply to all dimensions), or an array/tuple with entries for each spatial dimension.
Keywords
padding=0
: the number of extra zeros implicitly concatenated at the start and at the end of each dimension.
stride=1
: the number of elements to slide to reach the next filtering window.
mode=0
: 0 for convolution and 1 for cross-correlation.
alpha=1
: can be used to scale the result.
handle
: handle to a previously created cuDNN context. Defaults to a Knet allocated handle.
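The deconv4 output size follows from Yi = Wi + stride[i]*(Xi-1) - 2*padding[i], with the weight layout (W1,W2,...,O,I). deconvsize below is a hypothetical helper that performs only this arithmetic. It shows that deconvolution undoes the spatial shrinking of a matching convolution when the sizes divide evenly.

```julia
# Dimension arithmetic for deconv4: Yi = Wi + stride[i]*(Xi-1) - 2*padding[i].
# Hypothetical helper; xdims = (X1,...,I,N), wdims = (W1,...,O,I).
function deconvsize(xdims::Tuple, wdims::Tuple; padding=0, stride=1)
    nsp = length(xdims) - 2
    xdims[end-1] == wdims[end] || throw(DimensionMismatch("input channels differ"))
    pad = padding isa Number ? fill(padding, nsp) : collect(padding)
    str = stride isa Number ? fill(stride, nsp) : collect(stride)
    sp = ntuple(i -> wdims[i] + str[i] * (xdims[i] - 1) - 2pad[i], nsp)
    return (sp..., wdims[end-1], xdims[end])   # (Y1,...,O,N)
end
```

For example, a 5x5 filter maps 24x24 inputs back to 28x28, the reverse of the corresponding conv4 case.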
Knet.mat
— Function
mat(x)
Reshape x into a two-dimensional matrix.
This is typically used when turning the output of a 4-D convolution result into a 2-D input for a fully connected layer. For 1-D inputs, returns reshape(x, (length(x),1)). For inputs with more than two dimensions of size (X1,X2,...,XD), returns
reshape(x, (X1*X2*...*X[D-1],XD))
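For ordinary Julia arrays, the same reshaping can be written with a colon. mat_sketch is a hypothetical stand-in that follows the rules above: each column of the result holds one instance, because Julia arrays are column-major.

```julia
# Hypothetical sketch of mat for plain arrays: collapse every dimension except
# the last into rows, so column k holds instance k (Julia is column-major).
mat_sketch(x::AbstractVector) = reshape(x, (length(x), 1))
mat_sketch(x::AbstractArray) = reshape(x, :, size(x, ndims(x)))
```

A 4x4x3x10 convolution output becomes a 48x10 matrix, ready for a fully connected layer.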
Knet.pool
— Function
pool(x; kwargs...)
Compute pooling of input values (i.e., the maximum or average of several adjacent values) to produce an output with smaller height and/or width.
Currently 4 or 5 dimensional KnetArrays with Float32 or Float64 entries are supported. If x has dimensions (X1,X2,...,I,N), the result y will have dimensions (Y1,Y2,...,I,N) where
Yi=1+floor((Xi+2*padding[i]-window[i])/stride[i])
Here I is the number of input channels, N is the number of instances, and Xi,Yi are spatial dimensions. window, padding and stride are keyword arguments that can be specified as a single number (in which case they apply to all dimensions), or an array/tuple with entries for each spatial dimension.
Keywords:
window=2
: the pooling window size for each dimension.
padding=0
: the number of extra zeros implicitly concatenated at the start and at the end of each dimension.
stride=window
: the number of elements to slide to reach the next pooling window.
mode=0
: 0 for max, 1 for average including padded values, 2 for average excluding padded values.
maxpoolingNanOpt=0
: NaN values are not propagated if 0; they are propagated if 1.
alpha=1
: can be used to scale the result.
handle
: handle to a previously created cuDNN context. Defaults to a Knet allocated handle.
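A naive CPU loop makes the window/stride arithmetic of pool concrete. pool_sketch is a hypothetical helper that handles only 4-D arrays with padding=0 and scalar window and stride. With no padding, the two averaging modes coincide, so any nonzero mode here means "average".

```julia
# Naive CPU sketch of 2-D pooling on a 4-D array (X1,X2,I,N). Assumes padding=0
# and scalar window/stride; mode=0 takes the max, any other mode the mean.
function pool_sketch(x::AbstractArray{T,4}; window=2, stride=window, mode=0) where T
    X1, X2, I, N = size(x)
    Y1 = 1 + fld(X1 - window, stride)          # Yi = 1+floor((Xi-window)/stride)
    Y2 = 1 + fld(X2 - window, stride)
    y = similar(x, float(T), Y1, Y2, I, N)
    for n in 1:N, c in 1:I, j in 1:Y2, i in 1:Y1
        r = (i-1)*stride+1 : (i-1)*stride+window   # rows covered by this window
        s = (j-1)*stride+1 : (j-1)*stride+window   # columns covered by this window
        patch = view(x, r, s, c, n)
        y[i, j, c, n] = mode == 0 ? maximum(patch) : sum(patch) / length(patch)
    end
    return y
end
```

On a 4x4 input holding 1:16 in column-major order, the default 2x2 max pooling yields [6 14; 8 16].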
Knet.unpool
— Function
Unpooling; reverse of pooling.
x == pool(unpool(x;o...); o...)
Recurrent neural networks
Knet.rnninit
— Function
rnninit(inputSize, hiddenSize; opts...)
Return an (r,w) pair where r is an RNN struct and w is a single weight array that includes all matrices and biases for the RNN. Keyword arguments:
rnnType=:lstm
: Type of RNN: one of :relu, :tanh, :lstm, :gru.
numLayers=1
: Number of RNN layers.
bidirectional=false
: Create a bidirectional RNN if true.
dropout=0.0
: Dropout probability. Ignored if numLayers==1.
skipInput=false
: Do not multiply the input with a matrix if true.
dataType=Float32
: Data type to use for weights.
algo=0
: Algorithm to use, see CUDNN docs for details.
seed=0
: Random number seed. Uses time() if 0.
winit=xavier
: Weight initialization method for matrices.
binit=zeros
: Weight initialization method for bias vectors.
usegpu=(gpu()>=0)
: GPU used by default if one exists.
RNNs compute the output h[t] for a given iteration from the recurrent input h[t-1] and the previous layer input x[t] given matrices W, R and biases bW, bR from the following equations:
:relu and :tanh: Single gate RNN with activation function f:
h[t] = f(W * x[t] .+ R * h[t-1] .+ bW .+ bR)
:gru: Gated recurrent unit:
i[t] = sigm(Wi * x[t] .+ Ri * h[t-1] .+ bWi .+ bRi) # input gate
r[t] = sigm(Wr * x[t] .+ Rr * h[t-1] .+ bWr .+ bRr) # reset gate
n[t] = tanh(Wn * x[t] .+ r[t] .* (Rn * h[t-1] .+ bRn) .+ bWn) # new gate
h[t] = (1 - i[t]) .* n[t] .+ i[t] .* h[t-1]
:lstm: Long short term memory unit with no peephole connections:
i[t] = sigm(Wi * x[t] .+ Ri * h[t-1] .+ bWi .+ bRi) # input gate
f[t] = sigm(Wf * x[t] .+ Rf * h[t-1] .+ bWf .+ bRf) # forget gate
o[t] = sigm(Wo * x[t] .+ Ro * h[t-1] .+ bWo .+ bRo) # output gate
n[t] = tanh(Wn * x[t] .+ Rn * h[t-1] .+ bWn .+ bRn) # new gate
c[t] = f[t] .* c[t-1] .+ i[t] .* n[t] # cell output
h[t] = o[t] .* tanh(c[t])
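The LSTM and GRU equations above can be transcribed directly into one time step of plain Julia. This is a reading aid, not how Knet computes RNNs (it calls cuDNN on a packed weight array w). The per-gate matrices and biases are passed in NamedTuples whose field names (Wi, Ri, bWi, bRi, ...) are hypothetical and simply mirror the symbols in the equations.

```julia
# One time step of the LSTM and GRU equations above (plain Julia, hypothetical
# NamedTuple field names mirroring the symbols; cuDNN packs these inside w).
sigm(x) = 1 / (1 + exp(-x))

function lstm_step(p, x, h, c)
    i = sigm.(p.Wi * x .+ p.Ri * h .+ p.bWi .+ p.bRi)   # input gate
    f = sigm.(p.Wf * x .+ p.Rf * h .+ p.bWf .+ p.bRf)   # forget gate
    o = sigm.(p.Wo * x .+ p.Ro * h .+ p.bWo .+ p.bRo)   # output gate
    n = tanh.(p.Wn * x .+ p.Rn * h .+ p.bWn .+ p.bRn)   # new gate
    c2 = f .* c .+ i .* n                               # cell output c[t]
    return o .* tanh.(c2), c2                           # (h[t], c[t])
end

function gru_step(p, x, h)
    i = sigm.(p.Wi * x .+ p.Ri * h .+ p.bWi .+ p.bRi)          # input gate
    r = sigm.(p.Wr * x .+ p.Rr * h .+ p.bWr .+ p.bRr)          # reset gate
    n = tanh.(p.Wn * x .+ r .* (p.Rn * h .+ p.bRn) .+ p.bWn)   # new gate
    return (1 .- i) .* n .+ i .* h                             # h[t]
end
```

Here W* matrices are hiddenSize x inputSize and R* matrices are hiddenSize x hiddenSize. Running a sequence means calling the step function repeatedly, feeding each returned h (and c) back in.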
Knet.rnnforw
— Function
rnnforw(r, w, x[, hx, cx]; batchSizes, hy, cy)
Returns a tuple (y,hyout,cyout,rs) given rnn r, weights w, input x and optionally the initial hidden and cell states hx and cx (cx is only used in LSTMs). r and w should come from a previous call to rnninit. Both hx and cx are optional; they are treated as zero arrays if not provided. The output y contains the hidden states of the final layer for each time step, hyout and cyout give the final hidden and cell states for all layers, and rs is a buffer the RNN needs for its gradient calculation.
The boolean keyword arguments hy and cy control whether hyout and cyout will be output. By default hy = (hx!=nothing) and cy = (cx!=nothing && r.mode==2), i.e. a hidden state will be output if one is provided as input, and for the cell state we also require an LSTM. If hy/cy is false, hyout/cyout will be nothing. batchSizes can be an integer array that specifies non-uniform batch sizes as explained below. By default batchSizes=nothing and the same batch size, size(x,2), is used for all time steps.
The input and output dimensions are:
x
: (X,[B,T])
y
: (H/2H,[B,T])
hx, cx, hyout, cyout
: (H,B,L/2L)
batchSizes
: nothing or Vector{Int}(T)
where X is inputSize, H is hiddenSize, B is batchSize, T is seqLength, L is numLayers. x can be 1, 2, or 3 dimensional. If batchSizes==nothing, a 1-D x represents a single instance, a 2-D x represents a single minibatch, and a 3-D x represents a sequence of identically sized minibatches. If batchSizes is an array of (non-increasing) integers, it gives us the batch size for each time step in the sequence, in which case sum(batchSizes) should equal div(length(x),size(x,1)). y has the same dimensionality as x, differing only in its first dimension, which is H if the RNN is unidirectional, 2H if bidirectional. Hidden vectors hx, cx, hyout, cyout all have size (H,B1,L) for unidirectional RNNs, and (H,B1,2L) for bidirectional RNNs, where B1 is the size of the first minibatch.
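The two rules on batchSizes (entries must be non-increasing, and their sum must equal the number of columns of x) are easy to check before calling rnnforw. check_batchsizes is a hypothetical helper that simply restates these rules in code.

```julia
# Hypothetical check of the batchSizes rules above for an input x of size (X, ...):
# entries must be non-increasing and must sum to div(length(x), size(x,1)).
function check_batchsizes(x::AbstractArray, batchSizes::AbstractVector{<:Integer})
    issorted(batchSizes; rev=true) || return false
    return sum(batchSizes) == div(length(x), size(x, 1))
end
```

For a 5x10 input, batchSizes=[4,3,2,1] is valid (B1=4 then sets the second dimension of hx), while [1,2,3,4] is rejected because it increases.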
Knet.rnnparam
— Function
rnnparam(r::RNN, w, layer, id, param)
Return a single weight matrix or bias vector as a slice of w.
Valid layer
values:
- For unidirectional RNNs 1:numLayers
- For bidirectional RNNs 1:2*numLayers, forw and back layers alternate.
Valid id
values:
- For RELU and TANH RNNs, input = 1, hidden = 2.
- For GRU reset = 1,4; update = 2,5; newmem = 3,6; 1:3 for input, 4:6 for hidden
- For LSTM inputgate = 1,5; forget = 2,6; newmem = 3,7; output = 4,8; 1:4 for input, 5:8 for hidden
Valid param
values:
- Return the weight matrix (transposed!) if param==1.
- Return the bias vector if param==2.
The effect of skipInput: Let I=1 for RELU/TANH, 1:3 for GRU, 1:4 for LSTM
- For skipInput=false (default), rnnparam(r,w,1,I,1) is an (inputSize,hiddenSize) matrix.
- For skipInput=true, rnnparam(r,w,1,I,1) is nothing.
- For bidirectional, the same applies to rnnparam(r,w,2,I,1): the first back layer.
Knet.rnnparams
— Function
rnnparams(r::RNN, w)
Split w into individual parameters and return them as an array.
The order of params returned (subject to change):
- All weight matrices come before all bias vectors.
- Matrices and biases are sorted lexically based on (layer,id).
- See @doc rnnparam for valid layer and id values.
- Input multiplying matrices are nothing if r.inputMode = 1.
Batch Normalization
Knet.bnmoments
— Function
bnmoments(;momentum=0.1, mean=nothing, var=nothing, meaninit=zeros, varinit=ones)
can be used to directly load moments from data. meaninit and varinit are called if mean and var are nothing. The type and size of mean and var are determined automatically from the inputs in the batchnorm calls. A BNMoments object is returned.
BNMoments
A high-level data structure used to store running mean and running variance of batch normalization with the following fields:
momentum::AbstractFloat
: A real number between 0 and 1 to be used as the scale of last mean and variance. The existing running mean or variance is multiplied by (1-momentum).
mean
: The running mean.
var
: The running variance.
meaninit
: The function used to initialize the running mean. Should either be nothing or of the form (eltype, dims...)->data. zeros is a good option.
varinit
: The function used to initialize the running variance. Should either be nothing or of the form (eltype, dims...)->data. ones is a good option.
Knet.bnparams
— Function
bnparams(etype, channels)
creates a single 1d array that contains both scale and bias of batchnorm, where the first half is scale and the second half is bias.
bnparams(channels)
calls bnparams with etype=Float64, following Julia convention.
Knet.batchnorm
— Function
batchnorm(x[, moments, params]; kwargs...)
performs batch normalization on x with optional scaling factor and bias stored in params.
2d, 4d and 5d inputs are supported. Mean and variance are computed over dimensions (2,), (1,2,4) and (1,2,3,5) for 2d, 4d and 5d arrays, respectively.
moments stores the running mean and variance to be used in testing. It is optional in the training mode, but mandatory in the test mode. Training and test modes are controlled by the training keyword argument.
params stores the optional affine parameters gamma and beta. The bnparams function can be used to initialize params.
Example
# Initialization, C is an integer
moments = bnmoments()
params = bnparams(C)
# ...
# size(x) -> (H, W, C, N)
y = batchnorm(x, moments, params)
# size(y) -> (H, W, C, N)
Keywords
eps=1e-5
: The epsilon parameter added to the variance to avoid division by 0.
training
: When training is true, the mean and variance of x are used, and the moments argument is modified if it is provided. When training is false, the mean and variance stored in the moments argument are used. The default value is true when at least one of x and params is an AutoGrad.Value, and false otherwise.
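The computation batchnorm performs on a 4-D input can be sketched in plain Julia. This is written from the description above and is not Knet's cuDNN-backed code. RunningMoments is a hypothetical stand-in for BNMoments: it starts from zeros and ones, as with meaninit=zeros and varinit=ones, and scales the old running values by (1-momentum). The params vector follows the bnparams layout, with scale in the first half and bias in the second. Using the biased batch variance for both normalization and the running average is an assumption.

```julia
using Statistics

# Plain-Julia sketch of 4-D batchnorm (not Knet's implementation).
# RunningMoments is a hypothetical stand-in for BNMoments.
mutable struct RunningMoments
    momentum::Float64
    mean::Union{Nothing,Array{Float64,4}}
    var::Union{Nothing,Array{Float64,4}}
end
RunningMoments(; momentum=0.1) = RunningMoments(momentum, nothing, nothing)

function batchnorm_sketch(x::Array{Float64,4}, m::RunningMoments, params::Vector{Float64};
                          training=true, eps=1e-5)
    C = size(x, 3)
    gamma = reshape(params[1:C], 1, 1, C, 1)   # first half: scale
    beta = reshape(params[C+1:2C], 1, 1, C, 1) # second half: bias
    if m.mean === nothing                      # meaninit/varinit on first use
        m.mean, m.var = zeros(1, 1, C, 1), ones(1, 1, C, 1)
    end
    if training
        mu = mean(x; dims=(1, 2, 4))           # per-channel batch statistics
        sigma2 = var(x; dims=(1, 2, 4), corrected=false)
        k = m.momentum                         # old values are scaled by 1-momentum
        m.mean = k .* mu .+ (1 - k) .* m.mean
        m.var = k .* sigma2 .+ (1 - k) .* m.var
    else
        mu, sigma2 = m.mean, m.var             # test mode: stored moments
    end
    return gamma .* (x .- mu) ./ sqrt.(sigma2 .+ eps) .+ beta
end
```

In training mode each channel of the output has (approximately) zero mean and unit variance before the affine step; with training=false the same call uses the running moments instead.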
Optimization methods
Knet.update!
— Function
update!(weights, gradients, params)
update!(weights, gradients; lr=0.001, gclip=0)
Update the weights using their gradients and the optimization algorithm parameters specified by params. The 2-arg version defaults to the Sgd algorithm with learning rate lr and gradient clip gclip. gclip==0 indicates no clipping. The weights and possibly gradients and params are modified in-place.
weights can be an individual numeric array or a collection of arrays represented by an iterator or dictionary. In the individual case, gradients should be a similar numeric array of size(weights) and params should be a single object. In the collection case, each individual weight array should have a corresponding params object. This way different weight arrays can have their own optimization state, different learning rates, or even different optimization algorithms running in parallel. In the iterator case, gradients and params should be iterators of the same length as weights with corresponding elements. In the dictionary case, gradients and params should be dictionaries with the same keys as weights.
Individual optimization parameters can be one of the following types. The keyword arguments for each type's constructor and their default values are listed as well.
Sgd(;lr=0.001, gclip=0)
Momentum(;lr=0.001, gclip=0, gamma=0.9)
Nesterov(;lr=0.001, gclip=0, gamma=0.9)
Rmsprop(;lr=0.001, gclip=0, rho=0.9, eps=1e-6)
Adagrad(;lr=0.1, gclip=0, eps=1e-6)
Adadelta(;lr=0.01, gclip=0, rho=0.9, eps=1e-6)
Adam(;lr=0.001, gclip=0, beta1=0.9, beta2=0.999, eps=1e-8)
Example:
w = rand(d) # an individual weight array
g = lossgradient(w) # gradient g has the same shape as w
update!(w, g) # update w in-place with Sgd()
update!(w, g; lr=0.1) # update w in-place with Sgd(lr=0.1)
update!(w, g, Sgd(lr=0.1)) # update w in-place with Sgd(lr=0.1)
w = (rand(d1), rand(d2)) # a tuple of weight arrays
g = lossgradient2(w) # g will also be a tuple
p = (Adam(), Sgd()) # p has params for each w[i]
update!(w, g, p) # update each w[i] in-place with g[i],p[i]
w = Any[rand(d1), rand(d2)] # any iterator can be used
g = lossgradient3(w) # g will be similar to w
p = Any[Adam(), Sgd()] # p should be an iterator of same length
update!(w, g, p) # update each w[i] in-place with g[i],p[i]
w = Dict(:a => rand(d1), :b => rand(d2)) # dictionaries can be used
g = lossgradient4(w)
p = Dict(:a => Adam(), :b => Sgd())
update!(w, g, p)
Knet.optimizers
— Function
optimizers(model, otype; options...)
Given parameters of a model, initialize and return corresponding optimization parameters for a given optimization type otype and optimization options options. This is useful because each numeric array in model needs its own distinct optimization parameter. optimizers makes the creation of optimization parameters that parallel model parameters easy when all of them use the same type and options.
Knet.Adadelta
— Type
Adadelta(;lr=0.01, gclip=0, rho=0.9, eps=1e-6)
update!(w,g,p::Adadelta)
Container for parameters of the Adadelta optimization algorithm used by update!.
Adadelta is an extension of Adagrad that tries to prevent the decrease of the learning rates to zero as training progresses. It scales the learning rate based on the accumulated gradients like Adagrad and holds the acceleration term like Momentum. It updates the weights with the following formulas:
G = (1-rho) * g .^ 2 + rho * G
update = g .* sqrt(delta + eps) ./ sqrt(G + eps)
w = w - lr * update
delta = rho * delta + (1-rho) * update .^ 2
where w is the weight, g is the gradient of the objective function w.r.t w, lr is the learning rate, and G is an array with the same size and type as w that holds a running average of the squares of the gradients. eps is a small constant to prevent a zero value in the denominator. rho is the momentum parameter, and delta is an array with the same size and type as w that holds a running average of the squared updates.
If norm(g) > gclip > 0, g is scaled so that its norm is equal to gclip. If gclip==0, no scaling takes place.
Reference: Zeiler, M. D. (2012). ADADELTA: An Adaptive Learning Rate Method.
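Transcribing the Adadelta formulas above into Julia makes the order of the state updates explicit: G is updated before the step, and delta after it. AdadeltaState and adadelta_step! are hypothetical names for this CPU sketch, which omits gradient clipping and is not Knet's implementation.

```julia
# One Adadelta step transcribed from the formulas above (hypothetical names,
# plain Vector{Float64}, gradient clipping omitted).
mutable struct AdadeltaState
    lr::Float64; rho::Float64; eps::Float64
    G::Union{Nothing,Vector{Float64}}       # running average of g.^2
    delta::Union{Nothing,Vector{Float64}}   # running average of update.^2
end
AdadeltaState(; lr=0.01, rho=0.9, eps=1e-6) = AdadeltaState(lr, rho, eps, nothing, nothing)

function adadelta_step!(w::Vector{Float64}, g::Vector{Float64}, p::AdadeltaState)
    p.G === nothing && (p.G = zero(w); p.delta = zero(w))
    p.G .= (1 - p.rho) .* g .^ 2 .+ p.rho .* p.G
    upd = g .* sqrt.(p.delta .+ p.eps) ./ sqrt.(p.G .+ p.eps)
    w .-= p.lr .* upd
    p.delta .= p.rho .* p.delta .+ (1 - p.rho) .* upd .^ 2
    return w
end
```

Because delta starts at zero, the first steps are tiny (on the order of lr*sqrt(eps)). The step size then grows as delta accumulates.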
Knet.Adagrad
— Type
Adagrad(;lr=0.1, gclip=0, eps=1e-6)
update!(w,g,p::Adagrad)
Container for parameters of the Adagrad optimization algorithm used by update!.
Adagrad is one of the methods that adapt the learning rate to each of the weights. It stores the sum of the squares of the gradients and uses it to scale the learning rate: for each weight, the current gradient is divided by the square root of its accumulated squared gradients. Hence, the effective learning rate is larger for parameters whose accumulated gradients are small and smaller for parameters whose accumulated gradients are large. It updates the weights with the following formulas:
G = G + g .^ 2
w = w - g .* lr ./ sqrt(G + eps)
where w is the weight, g is the gradient of the objective function w.r.t w, lr is the learning rate, and G is an array with the same size and type as w that holds the sum of the squares of the gradients. eps is a small constant to prevent a zero value in the denominator.
If norm(g) > gclip > 0, g is scaled so that its norm is equal to gclip. If gclip==0, no scaling takes place.
Reference: Duchi, J., Hazan, E., & Singer, Y. (2011). Adaptive Subgradient Methods for Online Learning and Stochastic Optimization. Journal of Machine Learning Research, 12, 2121–2159.
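A direct transcription of the two Adagrad formulas shows how the step shrinks as G accumulates. AdagradState and adagrad_step! are hypothetical names for this CPU sketch, which omits gradient clipping.

```julia
# One Adagrad step transcribed from the formulas above (hypothetical names,
# plain Vector{Float64}, gradient clipping omitted). G only ever grows.
mutable struct AdagradState
    lr::Float64; eps::Float64
    G::Union{Nothing,Vector{Float64}}       # sum of squared gradients
end
AdagradState(; lr=0.1, eps=1e-6) = AdagradState(lr, eps, nothing)

function adagrad_step!(w::Vector{Float64}, g::Vector{Float64}, p::AdagradState)
    p.G === nothing && (p.G = zero(w))
    p.G .+= g .^ 2
    w .-= g .* p.lr ./ sqrt.(p.G .+ p.eps)
    return w
end
```

With a constant gradient of 2, the first step moves w by about lr = 0.1, and the second step by 0.2/sqrt(8), about 0.071.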
Knet.Adam
— Type
Adam(;lr=0.001, gclip=0, beta1=0.9, beta2=0.999, eps=1e-8)
update!(w,g,p::Adam)
Container for parameters of the Adam optimization algorithm used by update!.
Adam is one of the methods that compute an adaptive learning rate. It stores the accumulated gradients (first moment) and the accumulated squares of the gradients (second moment), and it corrects the bias of both moments as a function of the update count t. The update formulas are:
m = beta1 * m + (1 - beta1) * g
v = beta2 * v + (1 - beta2) * g .* g
mhat = m ./ (1 - beta1 ^ t)
vhat = v ./ (1 - beta2 ^ t)
w = w - lr * mhat ./ (sqrt(vhat) + eps)
where w is the weight, g is the gradient of the objective function w.r.t w, lr is the learning rate, m is an array with the same size and type as w that holds the accumulated gradients, and v is an array with the same size and type as w that holds the accumulated squares of the gradients. eps is a small constant to prevent a zero denominator. beta1 and beta2 are the parameters used to calculate the bias-corrected first and second moments. t is the update count.
If norm(g) > gclip > 0, g is scaled so that its norm is equal to gclip. If gclip==0, no scaling takes place.
Reference: Kingma, D. P., & Ba, J. L. (2015). Adam: a Method for Stochastic Optimization. International Conference on Learning Representations, 1–13.
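The Adam formulas translate line by line into Julia. The sketch uses hypothetical names, omits gradient clipping, and is not Knet's implementation. It also demonstrates a useful property of the bias correction: on the very first update, mhat = g and vhat = g.^2, so every weight moves by about lr regardless of how large or small its gradient is.

```julia
# One Adam step transcribed from the formulas above (hypothetical names,
# plain Vector{Float64}, gradient clipping omitted). t counts updates from 1.
mutable struct AdamState
    lr::Float64; beta1::Float64; beta2::Float64; eps::Float64
    t::Int
    m::Union{Nothing,Vector{Float64}}   # first-moment estimate
    v::Union{Nothing,Vector{Float64}}   # second-moment estimate
end
AdamState(; lr=0.001, beta1=0.9, beta2=0.999, eps=1e-8) =
    AdamState(lr, beta1, beta2, eps, 0, nothing, nothing)

function adam_step!(w::Vector{Float64}, g::Vector{Float64}, p::AdamState)
    p.m === nothing && (p.m = zero(w); p.v = zero(w))
    p.t += 1
    p.m .= p.beta1 .* p.m .+ (1 - p.beta1) .* g
    p.v .= p.beta2 .* p.v .+ (1 - p.beta2) .* g .* g
    mhat = p.m ./ (1 - p.beta1 ^ p.t)            # bias-corrected moments
    vhat = p.v ./ (1 - p.beta2 ^ p.t)
    w .-= p.lr .* mhat ./ (sqrt.(vhat) .+ p.eps)
    return w
end
```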
Knet.Momentum
— Type
Momentum(;lr=0.001, gclip=0, gamma=0.9)
update!(w,g,p::Momentum)
Container for parameters of the Momentum optimization algorithm used by update!.
The Momentum method tries to accelerate SGD by adding a velocity term to the update. This also decreases the oscillation between successive steps. It updates the weights with the following formulas:
velocity = gamma * velocity + lr * g
w = w - velocity
where w is a weight array, g is the gradient of the objective function w.r.t w, lr is the learning rate, gamma is the momentum parameter, and velocity is an array with the same size and type as w that holds the accelerated gradients.
If norm(g) > gclip > 0, g is scaled so that its norm is equal to gclip. If gclip==0, no scaling takes place.
Reference: Qian, N. (1999). On the momentum term in gradient descent learning algorithms. Neural Networks : The Official Journal of the International Neural Network Society, 12(1), 145–151.
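The two Momentum formulas can be written as a short CPU sketch. MomentumState and momentum_step! are hypothetical names, and gradient clipping is omitted. With a constant gradient, the velocity builds up geometrically (lr*g, then lr*g*(1+gamma), and so on), which is the acceleration described above.

```julia
# One Momentum step transcribed from the formulas above (hypothetical names,
# plain Vector{Float64}, gradient clipping omitted).
mutable struct MomentumState
    lr::Float64; gamma::Float64
    velocity::Union{Nothing,Vector{Float64}}
end
MomentumState(; lr=0.001, gamma=0.9) = MomentumState(lr, gamma, nothing)

function momentum_step!(w::Vector{Float64}, g::Vector{Float64}, p::MomentumState)
    p.velocity === nothing && (p.velocity = zero(w))
    p.velocity .= p.gamma .* p.velocity .+ p.lr .* g
    w .-= p.velocity
    return w
end
```

On the quadratic loss sum(w.^2)/2 (gradient w), repeated calls with lr=0.1 drive w toward zero.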
Knet.Nesterov
— Type
Nesterov(; lr=0.001, gclip=0, gamma=0.9)
update!(w,g,p::Nesterov)
Container for parameters of Nesterov's momentum optimization algorithm used by update!.
It is similar to standard Momentum but with a slightly different update rule:
velocity = gamma * velocity_old - lr * g
w = w_old - gamma * velocity_old + (1+gamma) * velocity
where w is a weight array, g is the gradient of the objective function w.r.t w, lr is the learning rate, gamma is the momentum parameter, and velocity is an array with the same size and type as w that holds the accelerated gradients.
If norm(g) > gclip > 0, g is scaled so that its norm is equal to gclip. If gclip == 0, no scaling takes place.
Reference implementation: Yoshua Bengio, Nicolas Boulanger-Lewandowski and Razvan Pascanu
Knet.Rmsprop
— Type
Rmsprop(;lr=0.001, gclip=0, rho=0.9, eps=1e-6)
update!(w,g,p::Rmsprop)
Container for parameters of the Rmsprop optimization algorithm used by update!.
Rmsprop scales the learning rate by dividing it by the root mean square of recent gradients. It updates the weights with the following formulas:
G = (1-rho) * g .^ 2 + rho * G
w = w - lr * g ./ sqrt(G + eps)
where w is the weight, g is the gradient of the objective function w.r.t w, lr is the learning rate, and G is an array with the same size and type as w that holds a running average of the squares of the gradients. eps is a small constant to prevent a zero value in the denominator. rho is the momentum parameter, i.e. the decay rate of the running average G.
If norm(g) > gclip > 0, g is scaled so that its norm is equal to gclip. If gclip==0, no scaling takes place.
Reference: Tijmen Tieleman and Geoffrey Hinton (2012). "Lecture 6.5-rmsprop: Divide the gradient by a running average of its recent magnitude." COURSERA: Neural Networks for Machine Learning 4.2.
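The two Rmsprop formulas can be transcribed as follows. RmspropState and rmsprop_step! are hypothetical names for this CPU sketch, which omits gradient clipping. Unlike Adagrad's ever-growing sum, G here is an exponentially decaying average, so the effective step size does not shrink toward zero over a long run.

```julia
# One Rmsprop step transcribed from the formulas above (hypothetical names,
# plain Vector{Float64}, gradient clipping omitted).
mutable struct RmspropState
    lr::Float64; rho::Float64; eps::Float64
    G::Union{Nothing,Vector{Float64}}       # running average of g.^2
end
RmspropState(; lr=0.001, rho=0.9, eps=1e-6) = RmspropState(lr, rho, eps, nothing)

function rmsprop_step!(w::Vector{Float64}, g::Vector{Float64}, p::RmspropState)
    p.G === nothing && (p.G = zero(w))
    p.G .= (1 - p.rho) .* g .^ 2 .+ p.rho .* p.G
    w .-= p.lr .* g ./ sqrt.(p.G .+ p.eps)
    return w
end
```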
Knet.Sgd
— Type
Sgd(;lr=0.001,gclip=0)
update!(w,g,p::Sgd)
update!(w,g;lr=0.001)
Container for parameters of the Stochastic gradient descent (SGD) optimization algorithm used by update!.
SGD is an optimization technique to minimize an objective function by updating its weights in the opposite direction of their gradient. The learning rate (lr) determines the size of the step. SGD updates the weights with the following formula:
w = w - lr * g
where w is a weight array, g is the gradient of the loss function w.r.t w and lr is the learning rate.
If norm(g) > gclip > 0, g is scaled so that its norm is equal to gclip. If gclip==0, no scaling takes place.
SGD is used by default if no algorithm is specified in the two-argument version of update!.
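The gclip rule, which every optimizer in this section shares, and the plain SGD step can be sketched together. gclip! and sgd_step! are hypothetical names used for illustration. Like update!, the sketch modifies the gradient in place when it clips it.

```julia
using LinearAlgebra

# If norm(g) > gclip > 0, rescale g in place so that its norm equals gclip;
# gclip == 0 leaves g untouched. Hypothetical helper names.
function gclip!(g::AbstractArray, gclip::Real)
    n = norm(g)
    if gclip > 0 && n > gclip
        g .*= gclip / n
    end
    return g
end

# Plain SGD step w = w - lr * g, applied after optional clipping.
function sgd_step!(w::AbstractArray, g::AbstractArray; lr=0.001, gclip=0)
    gclip!(g, gclip)
    w .-= lr .* g
    return w
end
```

For example, g = [3.0, 4.0] has norm 5; with gclip=1.0 it becomes [0.6, 0.8], so a step with lr=0.1 moves w by [0.06, 0.08].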
Hyperparameter optimization
Knet.goldensection
— Function.goldensection(f,n;kwargs) => (fmin,xmin)
Find the minimum of f using concurrent golden section search in n dimensions. See Knet.goldensection_demo() for an example.
f is a function from a Vector{Float64} of length n to a Number. It can return NaN for out-of-range inputs. Goldensection always starts with a zero vector as the initial input to f, and the initial step size is 1 in each dimension. The user should define f to scale and shift this input range into a vector meaningful for their application. For positive inputs like learning rate or hidden size, you can use a transformation such as x0*exp(x), where x is a value goldensection passes to f and x0 is your initial guess for this value. This effectively starts the search at x0, then moves with multiplicative steps.
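The x0*exp(x) transformation described above might look like this for two hyperparameters. toyloss and the initial guesses (0.01 and 64) are hypothetical placeholders for a real training run; the sketch only builds f and does not call Knet.goldensection itself.

```julia
# Hypothetical objective standing in for "train with these settings and return the loss".
toyloss(lr, hidden) = (log(lr) - log(0.05))^2 + (log(hidden) - log(128))^2

# goldensection starts at x = zeros(2) with unit steps; x0 * exp(x) maps x = 0
# to the initial guess x0 and turns additive steps in x into multiplicative steps.
function f(x::Vector{Float64})
    lr     = 0.01 * exp(x[1])              # initial guess lr = 0.01
    hidden = round(Int, 64 * exp(x[2]))    # initial guess hidden = 64
    hidden < 1 && return NaN               # signal an out-of-range input
    return toyloss(lr, hidden)
end
```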
I designed this algorithm by combining ideas from Golden Section Search and Hill Climbing Search. It essentially runs golden section search concurrently in each dimension, picking the next step based on estimated gain.
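For reference, the classic one-dimensional golden section search that this algorithm builds on can be sketched as below. This is the textbook bracketing method for a unimodal function on [a, b], not Knet's concurrent n-dimensional variant; golden1d is a hypothetical name.

```julia
# Textbook 1-D golden section search for a unimodal f on the bracket [a, b].
function golden1d(f, a, b; tol=1e-8)
    invphi = (sqrt(5) - 1) / 2          # 1/φ ≈ 0.618
    c = b - invphi * (b - a)
    d = a + invphi * (b - a)
    fc, fd = f(c), f(d)
    while b - a > tol
        if fc < fd                      # minimum lies in [a, d]; old c becomes new d
            b, d, fd = d, c, fc
            c = b - invphi * (b - a)
            fc = f(c)
        else                            # minimum lies in [c, b]; old d becomes new c
            a, c, fc = c, d, fd
            d = a + invphi * (b - a)
            fd = f(d)
        end
    end
    return (a + b) / 2
end
```

Each iteration reuses one of the two interior evaluations, so only one new call to f is needed while the bracket shrinks by a factor of about 0.618.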
Keyword arguments
dxmin=0.1: smallest step size.
accel=φ: acceleration rate. The golden ratio φ=1.618... is best.
verbose=false: use true to print individual steps.
history=[]: cache of [(x,f(x)),...] function evaluations.
Knet.hyperband
— Function.hyperband(getconfig, getloss, maxresource=27, reduction=3)
Hyperparameter optimization using the hyperband algorithm from (Li et al. 2016). You can try a simple MNIST example using Knet.hyperband_demo().
Arguments
getconfig() returns random configurations with a user-defined type and distribution.
getloss(c,n) returns the loss for configuration c and number of resources (e.g. epochs) n.
maxresource is the maximum number of resources any one configuration should be given.
reduction is an algorithm parameter (see the paper); 3 is a good value.
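To show how getconfig, getloss, maxresource and reduction interact, here is a hedged sketch of the Hyperband loop as described in the paper (brackets of successive halving). It is not Knet's source: hyperband_brackets and hyperband_sketch are hypothetical names, resources are passed to getloss as Float64, and the sketch returns the best (loss, config, resource) seen.

```julia
using Random

# Bracket schedule: for each bracket s, start n configurations with r resources each.
# Integer arithmetic avoids rounding errors in log(maxresource)/log(reduction).
function hyperband_brackets(maxresource, reduction)
    smax = 0
    while reduction^(smax + 1) <= maxresource
        smax += 1
    end
    budget = (smax + 1) * maxresource
    return [(s, cld(budget * reduction^s, maxresource * (s + 1)), maxresource / reduction^s)
            for s in smax:-1:0]
end

# Hypothetical sketch of the Hyperband loop (successive halving inside each bracket).
function hyperband_sketch(getconfig, getloss, maxresource=27, reduction=3)
    best = (loss = Inf, config = nothing, resource = 0.0)
    for (s, n, r) in hyperband_brackets(maxresource, reduction)
        configs = [getconfig() for _ in 1:n]
        for i in 0:s
            isempty(configs) && break
            ri = r * reduction^i                     # resources per configuration this round
            losses = [getloss(c, ri) for c in configs]
            k = argmin(losses)
            if losses[k] < best.loss
                best = (loss = losses[k], config = configs[k], resource = ri)
            end
            # keep the best floor(length(configs) / reduction) configurations
            nkeep = length(configs) ÷ reduction
            configs = configs[sortperm(losses)[1:nkeep]]
        end
    end
    return best
end
```

With the default maxresource=27 and reduction=3, the brackets start 27, 12, 6 and 4 configurations with 1, 3, 9 and 27 resources respectively.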
Initialization
Knet.bilinear
— Function.Bilinear interpolation filter weights; used for initializing deconvolution layers.
Adapted from https://github.com/shelhamer/fcn.berkeleyvision.org/blob/master/surgery.py#L33
Arguments:
T: data type
fw: width upscale factor
fh: height upscale factor
IN: number of input filters
ON: number of output filters
Example usage:
w = bilinear(Float32,2,2,128,128)
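The kind of weights bilinear produces can be sketched from the upsample_filt function in the referenced FCN surgery.py. This is an approximation, not Knet.bilinear's source: bilinear_kernel and bilinear_weights are hypothetical names, the kernel size 2f - f%2 for an integer upscale factor f is an assumed convention, and the 4-D array fills only matching input/output channel pairs, as FCN's interp is assumed to do.

```julia
# 2-D bilinear kernel for integer upscale factor f, following FCN's upsample_filt.
function bilinear_kernel(::Type{T}, f::Integer) where {T<:AbstractFloat}
    ks = 2f - f % 2                              # assumed kernel size for factor f
    factor = (ks + 1) ÷ 2
    center = isodd(ks) ? factor - 1 : factor - 0.5
    k = zeros(T, ks, ks)
    for j in 1:ks, i in 1:ks
        # zero-based offsets, as in the Python original
        k[i, j] = (1 - abs(i - 1 - center) / factor) * (1 - abs(j - 1 - center) / factor)
    end
    return k
end

# Each channel is upsampled independently: only w[:, :, c, c] is nonzero.
function bilinear_weights(::Type{T}, f::Integer, nch::Integer) where {T<:AbstractFloat}
    k = bilinear_kernel(T, f)
    w = zeros(T, size(k)..., nch, nch)
    for c in 1:nch
        w[:, :, c, c] = k
    end
    return w
end
```

For f=2 the 1-D profile is [0.25, 0.75, 0.75, 0.25], and the 2-D kernel is its outer product.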
Knet.gaussian
— Function.gaussian(a...; mean=0.0, std=0.01)
Return a Gaussian array with a given mean and standard deviation. The a arguments are passed to randn.
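The documented behavior amounts to scaling and shifting the output of randn. The sketch below assumes the element type chosen by randn (e.g. Float32) should be preserved; gaussian_sketch is a hypothetical name, not Knet's code.

```julia
# Draw randn(a...) and scale/shift it, keeping randn's element type.
function gaussian_sketch(a...; mean=0.0, std=0.01)
    r = randn(a...)
    T = eltype(r)
    return r .* T(std) .+ T(mean)
end

w = gaussian_sketch(Float32, 100, 50; mean=0.0, std=0.1)
```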
Knet.xavier
— Function.xavier(a...)
Xavier initialization. The a arguments are passed to rand. See (Glorot and Bengio 2010) for a description. Caffe implements this slightly differently. Lasagne calls it GlorotUniform.
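As a point of comparison, the uniform ("normalized") initialization from Glorot and Bengio (2010) draws each weight from U(-s, s) with s = sqrt(6/(fanin+fanout)). The sketch assumes a matrix used as W * x, so fanout = size(W,1) and fanin = size(W,2); glorot_uniform is a hypothetical name, and its constant and fan conventions may differ from Knet.xavier.

```julia
# Glorot & Bengio uniform initialization for a fanout x fanin matrix W used as W * x.
function glorot_uniform(::Type{T}, fanout::Integer, fanin::Integer) where {T<:AbstractFloat}
    s = T(sqrt(6 / (fanin + fanout)))
    # rand(T, ...) is in [0, 1), so 2r - 1 is in [-1, 1) and the entries fall in [-s, s)
    return (2 .* rand(T, fanout, fanin) .- 1) .* s
end

w = glorot_uniform(Float32, 10, 20)
```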
AutoGrad (advanced)
AutoGrad.getval
— Function.getval(x)
Unbox x if it is a boxed value (Rec), otherwise return x.
AutoGrad.@primitive
— Macro.@primitive fx g1 g2...
Define a new primitive operation for AutoGrad and (optionally) specify its gradients. Non-differentiable functions such as sign, and non-numeric functions such as size, should be defined using the @zerograd macro instead.
Examples
@primitive sin(x::Number)
@primitive hypot(x1,x2),dy,y
@primitive sin(x::Number),dy (dy.*cos(x))
@primitive hypot(x1,x2),dy,y (dy.*x1./y) (dy.*x2./y)
The first example shows that fx is a typed method declaration. Julia supports multiple dispatch, i.e. a single function can have multiple methods with different argument types. AutoGrad takes advantage of this and supports multiple dispatch for primitives and gradients.
The second example specifies variable names for the output gradient dy and the output y after the method declaration; these names can then be used in gradient expressions. Untyped, ellipsis and keyword arguments are ok, as in f(a::Int,b,c...;d=1). Parametric methods such as f(x::T) where {T<:Number} cannot be used.
The method declaration can optionally be followed by gradient expressions. The third and fourth examples show how gradients can be specified. Note that the parameters, the return variable and the output gradient of the original function can be used in the gradient expressions.
Under the hood
The @primitive macro turns the first example into:
sin(x::Rec{T}) where {T<:Number} = forw(sin, x)
This will cause calls to sin with a boxed argument (Rec{T<:Number}) to be recorded. The recorded operations are used by AutoGrad to construct a dynamic computational graph. With multiple arguments things are a bit more complicated. Here is what happens with the second example:
hypot(x1::Rec{S}, x2::Rec{T}) where {S<:Any,T<:Any} = forw(hypot, x1, x2)
hypot(x1::S, x2::Rec{T}) where {S<:Any,T<:Any} = forw(hypot, x1, x2)
hypot(x1::Rec{S}, x2::T) where {S<:Any,T<:Any} = forw(hypot, x1, x2)
We want the forw method to be called if any one of the arguments is a boxed Rec. There is no easy way to specify this in Julia, so for N arguments the macro generates all 2^N-1 boxed/unboxed argument combinations, i.e. every combination in which at least one argument is boxed.
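The count 2^N-1 can be checked by enumerating the patterns directly. This plain-Julia sketch only lists which argument patterns need a forw method; it does not generate AutoGrad code, and boxed_patterns and render are hypothetical names.

```julia
# Every boxed/unboxed pattern except "all unboxed"; true means the argument is a Rec.
function boxed_patterns(nargs::Integer)
    pats = Vector{Vector{Bool}}()
    for bits in Iterators.product(ntuple(_ -> (false, true), nargs)...)
        any(bits) && push!(pats, collect(bits))
    end
    return pats
end

# Render one pattern as a method header, e.g. "hypot(x1::Rec, x2)".
render(f, names, pat) =
    string(f, "(", join([p ? "$n::Rec" : "$n" for (n, p) in zip(names, pat)], ", "), ")")

headers = [render("hypot", ["x1", "x2"], p) for p in boxed_patterns(2)]
```

For hypot this yields three headers, matching the three forw methods listed above.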
In AutoGrad, gradients are defined using gradient methods that have the following pattern:
back(f,Val(i),dy,y,x...) => dx[i]
For the third example here is the generated gradient method:
back(::typeof(sin), ::Val{1}, dy, y, x::Rec{T}) where {T<:Number} = dy .* cos(x)
For the last example a different gradient method is generated for each argument:
back(::typeof(hypot), ::Val{1}, dy, y, x1::Rec{S}, x2::Rec{T}) where {S<:Any,T<:Any} = (dy .* x1) ./ y
back(::typeof(hypot), ::Val{2}, dy, y, x1::Rec{S}, x2::Rec{T}) where {S<:Any,T<:Any} = (dy .* x2) ./ y
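The two hypot gradient expressions, (dy .* x1) ./ y and (dy .* x2) ./ y, can be sanity-checked against central finite differences. This uses plain Julia only, with no AutoGrad involved; the point (3, 4) and the step size are arbitrary choices.

```julia
# Compare the analytic hypot gradients with central differences at (x1, x2) = (3, 4).
x1, x2, dy = 3.0, 4.0, 1.0
y = hypot(x1, x2)                                   # 5.0
analytic = (dy * x1 / y, dy * x2 / y)               # ≈ (0.6, 0.8)
h = 1e-6
numeric = ((hypot(x1 + h, x2) - hypot(x1 - h, x2)) / (2 * h),
           (hypot(x1, x2 + h) - hypot(x1, x2 - h)) / (2 * h))
```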
In fact @primitive generates four more definitions for the other boxed/unboxed argument combinations.
Broadcasting
Broadcasting is handled by extra forw and back methods. In broadcast.jl we define:
broadcasted(f, x::Rec) = forw(broadcast,f,x)
and similar methods that match any function f, so that forw is called whenever a boxed value takes part in a broadcasting operation. The @primitive macro defines the back method for broadcasting of a particular primitive:
back(::typeof(broadcast), ::Val{2}, dy, y, ::typeof(sin), x::Rec{T}) where {T<:Number} = dy .* cos(x)
If you do not want the back method for broadcasting, you can use the @primitive1 macro, which omits this final definition.
AutoGrad.@zerograd
— Macro.@zerograd f(args...; kwargs...)
Define f as an AutoGrad primitive operation with zero gradient.
Example:
@zerograd floor(x::Float32)
@zerograd allows f to handle boxed Rec inputs by unboxing them like a @primitive, but unlike @primitive it does not record its actions or return a boxed Rec result. Some functions, like sign(), have zero gradient. Others, like length(), have discrete or constant outputs. These need to handle Rec inputs, but do not need to record anything and can return regular values. Their output can be treated like a constant in the program. Use the @zerograd macro for those. Use the @zerograd1 variant if you don't want to define the broadcasting version. Note that kwargs are NOT unboxed.
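The unbox-and-return-a-plain-value idea behind @zerograd (and getval) can be illustrated with a toy wrapper type. Box, getval_sketch and the floor method below are hypothetical stand-ins, not AutoGrad's Rec or the code the macro actually generates.

```julia
# Toy boxed type standing in for AutoGrad's Rec.
struct Box{T}
    value::T
end
getval_sketch(x::Box) = x.value      # unbox, like getval
getval_sketch(x) = x                 # anything else is returned unchanged

# Zero-gradient behavior: accept a boxed input, compute on the unboxed value,
# and return a regular (unboxed, unrecorded) Float32.
Base.floor(x::Box{Float32}) = floor(getval_sketch(x))
```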
Function Index
Knet.Adadelta
Knet.Adagrad
Knet.Adam
Knet.KnetArray
Knet.Momentum
Knet.Nesterov
Knet.Rmsprop
Knet.Sgd
AutoGrad.getval
AutoGrad.grad
AutoGrad.gradcheck
AutoGrad.gradloss
Knet.accuracy
Knet.batchnorm
Knet.bilinear
Knet.bnmoments
Knet.bnparams
Knet.conv4
Knet.deconv4
Knet.dir
Knet.dropout
Knet.gaussian
Knet.gc
Knet.goldensection
Knet.gpu
Knet.hyperband
Knet.invx
Knet.logp
Knet.logsumexp
Knet.mat
Knet.minibatch
Knet.nll
Knet.optimizers
Knet.pool
Knet.relu
Knet.rnnforw
Knet.rnninit
Knet.rnnparam
Knet.rnnparams
Knet.seed!
Knet.sigm
Knet.unpool
Knet.update!
Knet.xavier
AutoGrad.@primitive
AutoGrad.@zerograd