Reference
AutoGrad
AutoGrad.grad
— Function.grad(fun, argnum=1)
Take a function fun(X...)->Y and return another function gfun(X...)->dXi which computes its gradient with respect to positional argument number argnum. The function fun should be scalar-valued. The returned function gfun takes the same arguments as fun, but returns the gradient instead. The gradient has the same type and size as the target argument, which can be a Number, Array, Tuple, or Dict.
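For example, a minimal sketch (the function and input below are made up for illustration; grad is exported by AutoGrad and re-exported by Knet):
f(x) = sum(x .* x) / 2        # a scalar-valued function
g = grad(f)                   # g(x) computes the gradient of f at x
x = [1.0, 2.0, 3.0]
g(x)                          # returns [1.0, 2.0, 3.0], i.e. df/dx == x for this f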
AutoGrad.gradloss
— Function.gradloss(fun, argnum=1)
Another version of grad where the generated function returns a (gradient,value) pair.
AutoGrad.gradcheck
— Function.gradcheck(f, w, x...; kwargs...)
Numerically check the gradient of f(w,x...;o...) with respect to its first argument w and return a boolean result. The argument w can be a Number, Array, Tuple or Dict which in turn can contain other Arrays etc. Only the largest 10 entries in each numerical gradient array are checked by default. If the output of f is not a number, gradcheck constructs and checks a scalar function by taking its dot product with a random vector.
Keywords
gcheck=10: number of largest entries from each numeric array in gradient dw=(grad(f))(w,x...;o...) compared to their numerical estimates.
verbose=false: print detailed messages if true.
kwargs=[]: keyword arguments to be passed to f.
delta=atol=rtol=cbrt(eps(w)): tolerance parameters. See isapprox for their meaning.
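For example (a small sketch; the function and arrays are made up for illustration):
f(w, x) = sum(w .* x)         # scalar loss with the weights w as first argument
w = rand(3); x = rand(3)
gradcheck(f, w, x)            # true if the automatic and numeric gradients agree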
KnetArray
Knet.KnetArray
— Type.KnetArray{T}(dims)
KnetArray(a::AbstractArray)
Array(k::KnetArray)
Container for GPU arrays that supports most of the AbstractArray interface. The constructor allocates a KnetArray in the currently active device, as specified by gpu(). KnetArrays and Arrays can be converted to each other as shown above, which involves copying to and from the GPU memory. Only Float32/64 KnetArrays are fully supported.
Important differences from the alternative CudaArray are: (1) a custom memory manager that minimizes the number of calls to the slow cudaMalloc by reusing already allocated but garbage collected GPU pointers. (2) a custom getindex that handles ranges such as a[5:10] as views with shared memory instead of copies. (3) custom CUDA kernels that implement elementwise, broadcasting, and reduction operations.
Supported functions:
Array operations: ==, !=, cat, convert, copy, copy!, deepcopy, display, eachindex, eltype, endof, fill!, first, getindex, hcat, isapprox, isempty, length, linearindexing, ndims, ones, pointer, rand!, reshape, setindex!, similar, size, stride, strides, summary, vcat, vec, zeros. (Only Integer, Colon, and UnitRange indices supported for get/setindex. CartesianIndex, StepRange, Array, and Bool indices not supported. cat(i,x,y) supported for i=1,2.)
Math operators: (-), abs, abs2, acos, acosh, asin, asinh, atan, atanh, cbrt, ceil, cos, cosh, cospi, erf, erfc, erfcinv, erfcx, erfinv, exp, exp10, exp2, expm1, floor, log, log10, log1p, log2, round, sign, sin, sinh, sinpi, sqrt, tan, tanh, trunc
Broadcasting operators: (.*), (.+), (.-), (./), (.<), (.<=), (.!=), (.==), (.>), (.>=), (.^), max, min. (Only Array-Scalar and Array-Vector broadcasting are supported. Boolean operators generate outputs with same type as inputs; no support for KnetArray{Bool}.)
Reduction operators: countnz, maximum, minimum, prod, sum, sumabs, sumabs2, vecnorm. (Only Array->Scalar and Array->Vector reductions are supported)
Linear algebra: (*), axpy!, permutedims (only 2D and 3D), transpose
Knet extras: relu, sigm, invx, logp, logsumexp, conv4, pool, deconv4, unpool, mat, update! (Only 4D/5D, Float32/64 KnetArrays support conv4, pool, deconv4, unpool)
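A minimal usage sketch (assumes Knet is loaded and a working GPU/CUDA setup is available):
a = rand(Float32, 3, 4)       # a regular Array in CPU memory
k = KnetArray(a)              # copy it to the active GPU device
s = sum(k .* k)               # elementwise ops and reductions run on the GPU
b = Array(k)                  # copy the data back into a CPU Array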
Memory management
Knet models do not overwrite arrays that need to be preserved for gradient calculation. This leads to a lot of allocation, and regular GPU memory allocation is prohibitively slow. Fortunately, most models use identically sized arrays over and over again, so we can minimize the number of actual allocations by reusing preallocated but garbage collected pointers.
When Julia gc reclaims a KnetArray, a special finalizer keeps its pointer in a table instead of releasing the memory. If an array with the same size in bytes is later requested, the same pointer is reused. The exact algorithm for allocation is:
Try to find a previously allocated and garbage collected pointer in the current device. (0.5 μs)
If not available, try to allocate a new array using cudaMalloc. (10 μs)
If not successful, try running gc() and see if we get a pointer of the right size. (75 ms, but this should be amortized over all reusable pointers that become available due to the gc)
Finally if all else fails, clean up all saved pointers in the current device using cudaFree and try allocation one last time. (25-70 ms, however this causes the elimination of all reusable pointers)
Utilities
Knet.dir
— Function.Knet.dir(path...)
Construct a path relative to Knet root.
Example
julia> Knet.dir("examples","mnist.jl")
"/home/dyuret/.julia/v0.5/Knet/examples/mnist.jl"
Knet.gpu
— Function.gpu()
returns the id of the active GPU device or -1 if none are active.
gpu(true)
resets all GPU devices and activates the one with the most available memory.
gpu(false)
resets and deactivates all GPU devices.
gpu(d::Int)
activates the GPU device d if 0 <= d < gpuCount(), otherwise deactivates devices.
gpu(true/false) resets all devices. If there are any allocated KnetArrays, their pointers will be left dangling. Thus gpu(true/false) should only be used during startup. If you want to suspend GPU use temporarily, use gpu(-1).
gpu(d::Int) does not reset the devices. You can select a previous device and find its allocated memory preserved. However, trying to operate on arrays of an inactive device will result in an error.
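For example, a typical startup sequence might look like the following sketch (actual device ids and availability depend on the machine):
gpu()                         # id of the active device, or -1 if none is active
gpu(true)                     # reset all devices, activate the one with most free memory
gpu(0)                        # or explicitly activate device 0
gpu(-1)                       # temporarily suspend GPU use without resetting devices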
Knet.logp
— Function.logp(x,[dims])
Treat entries in x as unnormalized log probabilities and return normalized log probabilities.
dims is an optional argument; if not specified, the normalization is over the whole x, otherwise the normalization is performed over the given dimensions. In particular, if x is a matrix, dims=1 normalizes columns of x and dims=2 normalizes rows of x.
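For example (a small sketch with made-up scores):
x = randn(3, 5)               # unnormalized log probabilities: 3 classes, 5 instances
lp = logp(x, 1)               # normalize each column
# each column of exp(lp) now sums to (approximately) one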
Knet.logsumexp
— Function.logsumexp(x,[dims])
Compute log(sum(exp(x),dims)) in a numerically stable manner.
dims is an optional argument; if not specified, the summation is over the whole x, otherwise the summation is performed over the given dimensions. In particular, if x is a matrix, dims=1 sums columns of x and dims=2 sums rows of x.
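For example (sizes are illustrative):
x = randn(3, 5)
logsumexp(x)                  # a scalar: log(sum(exp(x))) over the whole array
logsumexp(x, 1)               # one value per column, computed in a stable way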
Knet.invx
— Function.invx(x) = (1./x)
Knet.relu
— Function.relu(x) = max(0,x)
Knet.sigm
— Function.sigm(x) = (1./(1+exp(-x)))
Convolution
Knet.conv4
— Function.conv4(w, x; kwargs...)
Execute convolutions or cross-correlations using filters specified with w over tensor x.
Currently KnetArray{Float32/64,4/5} and Array{Float32/64,4} are supported as w and x. If w has dimensions (W1,W2,...,I,O) and x has dimensions (X1,X2,...,I,N), the result y will have dimensions (Y1,Y2,...,O,N) where
Yi=1+floor((Xi+2*padding[i]-Wi)/stride[i])
Here I is the number of input channels, O is the number of output channels, N is the number of instances, and Wi,Xi,Yi are spatial dimensions. padding and stride are keyword arguments that can be specified as a single number (in which case they apply to all dimensions), or an array/tuple with entries for each spatial dimension.
Keywords
padding=0: the number of extra zeros implicitly concatenated at the start and at the end of each dimension.
stride=1: the number of elements to slide to reach the next filtering window.
upscale=1: upscale factor for each dimension.
mode=0: 0 for convolution and 1 for cross-correlation.
alpha=1: can be used to scale the result.
algo=0: specifies which convolution algorithm should be used to compute the results. See the cuDNN User Guide for details.
workSpace=C_NULL: data pointer to GPU memory for a workspace needed to be able to execute the specified algorithm.
workSpaceSizeInBytes=0: the size in bytes of the provided workSpace.
handle: handle to a previously created cuDNN context. Defaults to a Knet allocated handle.
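A shape-level sketch (the filter and input sizes below are made up; the expected output sizes follow from the formula above):
w = rand(Float32, 3, 3, 1, 2)      # 3x3 filters, 1 input channel, 2 output channels
x = rand(Float32, 28, 28, 1, 10)   # 28x28 inputs, 1 channel, 10 instances
y = conv4(w, x)                    # expected size: (26, 26, 2, 10) with padding=0, stride=1
y2 = conv4(w, x; padding=1)        # expected size: (28, 28, 2, 10)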
Knet.pool
— Function.pool(x; kwargs...)
Compute pooling of input values (i.e., the maximum or average of several adjacent values) to produce an output with smaller height and/or width.
Currently 4 or 5 dimensional KnetArrays with Float32 or Float64 entries are supported. If x has dimensions (X1,X2,...,I,N), the result y will have dimensions (Y1,Y2,...,I,N) where
Yi=1+floor((Xi+2*padding[i]-window[i])/stride[i])
Here I is the number of input channels, N is the number of instances, and Xi,Yi are spatial dimensions. window, padding and stride are keyword arguments that can be specified as a single number (in which case they apply to all dimensions), or an array/tuple with entries for each spatial dimension.
Keywords:
window=2: the pooling window size for each dimension.
padding=0: the number of extra zeros implicitly concatenated at the start and at the end of each dimension.
stride=window: the number of elements to slide to reach the next pooling window.
mode=0: 0 for max, 1 for average including padded values, 2 for average excluding padded values.
maxpoolingNanOpt=0: NaN values are not propagated if 0, they are propagated if 1.
alpha=1: can be used to scale the result.
handle: handle to a previously created cuDNN context. Defaults to a Knet allocated handle.
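A shape-level sketch (sizes are made up; a GPU is assumed since pool expects KnetArrays):
x = KnetArray(rand(Float32, 28, 28, 1, 10))   # 28x28 inputs, 1 channel, 10 instances
y = pool(x)                                    # window=2 by default: expected size (14, 14, 1, 10)
y3 = pool(x; window=3, stride=3)               # expected size (9, 9, 1, 10)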
Knet.mat
— Function.mat(x)
Reshape x into a two-dimensional matrix.
This is typically used when turning the output of a 4-D convolution into a 2-D input for a fully connected layer. For 1-D inputs, it returns reshape(x, (length(x),1)). For inputs with more than two dimensions of size (X1,X2,...,XD), it returns
reshape(x, (X1*X2*...*X[D-1],XD))
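For example (sizes are illustrative):
y = rand(Float32, 7, 7, 16, 10)    # e.g. the output of a convolution/pooling stack
m = mat(y)                         # expected size: (7*7*16, 10) == (784, 10)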
Knet.deconv4
— Function.Deconvolution; reverse of convolution.
Knet.unpool
— Function.Unpooling; reverse of pooling.
x == pool(unpool(x;o...); o...)
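A shape-level sketch of the identity above (sizes are made up and a GPU is assumed):
x = KnetArray(rand(Float32, 14, 14, 8, 10))
u = unpool(x)                      # upscales by the default window=2: expected size (28, 28, 8, 10)
pool(u) == x                       # the identity stated above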
Optimization methods
Knet.update!
— Function.update!(weights, gradients, params)
update!(weights, gradients; lr=0.001)
Update the weights using their gradients and the optimization algorithm parameters specified by params. The 2-arg version defaults to the Sgd algorithm with learning rate lr. The weights and possibly params are modified in-place.
weights can be an individual numeric array or a collection of arrays represented by an iterator or dictionary. In the individual case, gradients should be a similar numeric array of size(weights) and params should be a single object. In the collection case, each individual weight array should have a corresponding params object. This way different weight arrays can have their own optimization state, different learning rates, or even different optimization algorithms running in parallel. In the iterator case, gradients and params should be iterators of the same length as weights with corresponding elements. In the dictionary case, gradients and params should be dictionaries with the same keys as weights.
Individual optimization parameters can be one of the following types:
Sgd(;lr=0.001)
Momentum(;lr=0.001, gamma=0.9)
Rmsprop(;lr=0.001, rho=0.9, eps=1e-6)
Adagrad(;lr=0.1, eps=1e-6)
Adadelta(;lr=0.01, rho=0.9, eps=1e-6)
Adam(;lr=0.001, beta1=0.9, beta2=0.999, eps=1e-8)
Example:
w = rand(d) # an individual weight array
g = lossgradient(w) # gradient g has the same shape as w
update!(w, g) # update w in-place with Sgd()
update!(w, g; lr=0.1) # update w in-place with Sgd(lr=0.1)
update!(w, g, Sgd(lr=0.1)) # update w in-place with Sgd(lr=0.1)
w = (rand(d1), rand(d2)) # a tuple of weight arrays
g = lossgradient2(w) # g will also be a tuple
p = (Adam(), Sgd()) # p has params for each w[i]
update!(w, g, p) # update each w[i] in-place with g[i],p[i]
w = Any[rand(d1), rand(d2)] # any iterator can be used
g = lossgradient3(w) # g will be similar to w
p = Any[Adam(), Sgd()] # p should be an iterator of same length
update!(w, g, p) # update each w[i] in-place with g[i],p[i]
w = Dict(:a => rand(d1), :b => rand(d2)) # dictionaries can be used
g = lossgradient4(w)
p = Dict(:a => Adam(), :b => Sgd())
update!(w, g, p)
Knet.Sgd
— Type.Sgd(;lr=0.001)
update!(w,g,p::Sgd)
update!(w,g;lr=0.001)
Container for parameters of the Stochastic gradient descent (SGD) optimization algorithm used by update!.
SGD is an optimization technique to minimize an objective function by updating its weights in the opposite direction of their gradient. The learning rate (lr) determines the size of the step. SGD updates the weights with the following formula:
w = w - lr * g
where w is a weight array, g is the gradient of the loss function w.r.t w, and lr is the learning rate.
SGD is used by default if no algorithm is specified in the two-argument version of update!.
Knet.Momentum
— Type.Momentum(;lr=0.001, gamma=0.9)
update!(w,g,p::Momentum)
Container for parameters of the Momentum optimization algorithm used by update!.
The Momentum method tries to accelerate SGD by adding a velocity term to the update. This also decreases the oscillation between successive steps. It updates the weights with the following formulas:
velocity = gamma * velocity + lr * g
w = w - velocity
where w is a weight array, g is the gradient of the objective function w.r.t w, lr is the learning rate, gamma is the momentum parameter, and velocity is an array with the same size and type as w that holds the accelerated gradients.
Reference: Qian, N. (1999). On the momentum term in gradient descent learning algorithms. Neural Networks : The Official Journal of the International Neural Network Society, 12(1), 145–151.
Knet.Adagrad
— Type.Adagrad(;lr=0.1, eps=1e-6)
update!(w,g,p::Adagrad)
Container for parameters of the Adagrad optimization algorithm used by update!.
Adagrad is one of the methods that adapt the learning rate individually for each weight. It accumulates the sum of the squares of the gradients and scales the current gradient by the inverse square root of this accumulator. Hence, the effective learning rate is larger for parameters whose accumulated gradients are small and smaller for parameters whose accumulated gradients are large. It updates the weights with the following formulas:
G = G + g .^ 2
w = w - g .* lr ./ sqrt(G + eps)
where w is the weight, g is the gradient of the objective function w.r.t w, lr is the learning rate, G is an array with the same size and type as w that holds the sum of the squares of the gradients, and eps is a small constant to prevent a zero value in the denominator.
Reference: Duchi, J., Hazan, E., & Singer, Y. (2011). Adaptive Subgradient Methods for Online Learning and Stochastic Optimization. Journal of Machine Learning Research, 12, 2121–2159.
Knet.Adadelta
— Type.Adadelta(;lr=0.01, rho=0.9, eps=1e-6)
update!(w,g,p::Adadelta)
Container for parameters of the Adadelta optimization algorithm used by update!.
Adadelta is an extension of Adagrad that tries to prevent the decrease of the learning rates to zero as training progresses. It scales the learning rate based on the accumulated gradients like Adagrad and holds the acceleration term like Momentum. It updates the weights with the following formulas:
G = (1-rho) * g .^ 2 + rho * G
update = g .* sqrt(delta + eps) ./ sqrt(G + eps)
w = w - lr * update
delta = rho * delta + (1-rho) * update .^ 2
where w is the weight, g is the gradient of the objective function w.r.t w, lr is the learning rate, G is an array with the same size and type as w that holds the sum of the squares of the gradients, eps is a small constant to prevent a zero value in the denominator, rho is the momentum parameter, and delta is an array with the same size and type as w that holds the sum of the squared updates.
Reference: Zeiler, M. D. (2012). ADADELTA: An Adaptive Learning Rate Method.
Knet.Rmsprop
— Type.Rmsprop(;lr=0.001, rho=0.9, eps=1e-6)
update!(w,g,p::Rmsprop)
Container for parameters of the Rmsprop optimization algorithm used by update!.
Rmsprop scales the learning rate by dividing the gradient by a running root mean square of recent gradients. It updates the weights with the following formulas:
G = (1-rho) * g .^ 2 + rho * G
w = w - lr * g ./ sqrt(G + eps)
where w is the weight, g is the gradient of the objective function w.r.t w, lr is the learning rate, G is an array with the same size and type as w that holds the running average of the squares of the gradients, eps is a small constant to prevent a zero value in the denominator, and rho is the momentum parameter.
Reference: Tijmen Tieleman and Geoffrey Hinton (2012). "Lecture 6.5-rmsprop: Divide the gradient by a running average of its recent magnitude." COURSERA: Neural Networks for Machine Learning 4.2.
Knet.Adam
— Type.Adam(;lr=0.001, beta1=0.9, beta2=0.999, eps=1e-8)
update!(w,g,p::Adam)
Container for parameters of the Adam optimization algorithm used by update!.
Adam is one of the methods that compute an adaptive learning rate. It stores the accumulated gradients (first moment) and the accumulated squared gradients (second moment), and scales both as a function of time to correct their bias. Here are the update formulas:
m = beta1 * m + (1 - beta1) * g
v = beta2 * v + (1 - beta2) * g .* g
mhat = m ./ (1 - beta1 ^ t)
vhat = v ./ (1 - beta2 ^ t)
w = w - (lr / (sqrt(vhat) + eps)) * mhat
where w is the weight, g is the gradient of the objective function w.r.t w, lr is the learning rate, m is an array with the same size and type as w that holds the accumulated gradients, v is an array with the same size and type as w that holds the sum of the squares of the gradients, eps is a small constant to prevent a zero denominator, beta1 and beta2 are the parameters used to calculate the bias-corrected first and second moments, and t is the update count.
Reference: Kingma, D. P., & Ba, J. L. (2015). Adam: a Method for Stochastic Optimization. International Conference on Learning Representations, 1–13.
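Putting the pieces together, a hedged sketch of a training loop using Adam with update! (the model, loss, weight shapes and the data iterator below are made up for illustration):
predict(w, x) = w[1] * x .+ w[2]                   # a linear model; x assumed to be a 13xN minibatch
loss(w, x, y) = sum((predict(w, x) - y).^2) / size(y, 2)
lossgradient = grad(loss)                          # gradient w.r.t. the first argument w
w = Any[0.1*randn(1, 13), zeros(1, 1)]             # made-up weight shapes
p = Any[Adam(), Adam()]                            # one optimizer state per weight array
for (x, y) in data                                 # data: an assumed iterator of (input, target) minibatches
    g = lossgradient(w, x, y)
    update!(w, g, p)
end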
Initialization
Knet.gaussian
— Function.gaussian(a...; mean=0.0, std=0.01)
Return a Gaussian array with a given mean and standard deviation. The a arguments are passed to randn.
Knet.xavier
— Function.xavier(a...)
Xavier initialization. The a arguments are passed to rand. See (Glorot and Bengio 2010) for a description. Caffe implements this slightly differently. Lasagne calls it GlorotUniform.
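For example (a sketch with made-up layer sizes; gaussian from above is shown for comparison):
w1 = xavier(100, 784)              # weights of a 784->100 fully connected layer
b1 = zeros(100, 1)                 # biases are typically initialized to zero
w2 = gaussian(10, 100; std=0.1)    # a Gaussian-initialized 10x100 weight matrix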
Knet.bilinear
— Function.Bilinear interpolation filter weights; used for initializing deconvolution layers.
Adapted from https://github.com/shelhamer/fcn.berkeleyvision.org/blob/master/surgery.py#L33
Arguments:
T: Data Type
fw: Width upscale factor
fh: Height upscale factor
IN: Number of input filters
ON: Number of output filters
Example usage:
w = bilinear(Float32,2,2,128,128)
AutoGrad (advanced)
AutoGrad.@primitive
— Macro.@primitive fx g1 g2...
Define a new primitive operation for AutoGrad and (optionally) specify its gradients. Non-differentiable functions such as sign and non-numeric functions such as size should be defined using the @zerograd macro instead.
Examples
@primitive sin(x::Number)
@primitive hypot(x1::Array,x2::Array),dy,y
@primitive sin(x::Number),dy (dy*cos(x))
@primitive hypot(x1::Array,x2::Array),dy,y (dy.*x1./y) (dy.*x2./y)
The first example shows that fx is a typed method declaration. Julia supports multiple dispatch, i.e. a single function can have multiple methods with different arg types. AutoGrad takes advantage of this and supports multiple dispatch for primitives and gradients.
The second example specifies variable names for the output gradient dy and the output y after the method declaration; these can be used in gradient expressions. Untyped, ellipsis and keyword arguments are ok as in f(a::Int,b,c...;d=1). Parametric methods such as f{T<:Number}(x::T) cannot be used.
The method declaration can optionally be followed by gradient expressions. The third and fourth examples show how gradients can be specified. Note that the parameters, the return variable and the output gradient of the original function can be used in the gradient expressions.
Under the hood
The @primitive macro turns the first example into:
local sin_r = recorder(sin)
sin{T<:Number}(x::Rec{T}) = sin_r(x)
This will cause calls to sin with a boxed argument (Rec{T<:Number}) to be recorded. The recorded operations are used by grad to construct a dynamic computational graph. With multiple arguments things are a bit more complicated. Here is what happens with the second example:
local hypot_r = recorder(hypot)
hypot{T<:Array,S<:Array}(x1::Rec{T},x2::Rec{S})=hypot_r(x1,x2)
hypot{T<:Array,S<:Array}(x1::Rec{T},x2::S)=hypot_r(x1,x2)
hypot{T<:Array,S<:Array}(x1::T,x2::Rec{S})=hypot_r(x1,x2)
We want the recorder version to be called if any one of the arguments is a boxed Rec. There is no easy way to specify this in Julia, so the macro generates all 2^N-1 boxed/unboxed argument combinations.
In AutoGrad, gradients are defined using gradient methods that have the following signature:
f(Grad{i},dy,y,x...) => dx[i]
For the third example here is the generated gradient method:
sin{T<:Number}(::Type{Grad{1}}, dy, y, x::Rec{T})=(dy*cos(x))
For the last example a different gradient method is generated for each argument:
hypot{T<:Array,S<:Array}(::Type{Grad{1}},dy,y,x1::Rec{T},x2::Rec{S})=(dy.*x1./y)
hypot{T<:Array,S<:Array}(::Type{Grad{2}},dy,y,x1::Rec{T},x2::Rec{S})=(dy.*x2./y)
In fact @primitive generates four more definitions for the other boxed/unboxed argument combinations.
AutoGrad.@zerograd
— Macro.@zerograd f(args...; kwargs...)
Define f as an AutoGrad primitive operation with zero gradient.
Example:
@zerograd floor(x::Float32)
@zerograd allows f to handle boxed Rec inputs by unboxing them like a @primitive, but unlike @primitive it does not record its actions or return a boxed Rec result. Some functions, like sign(), have zero gradient. Others, like length(), have discrete or constant outputs. These need to handle Rec inputs, but do not need to record anything and can return regular values. Their output can be treated like a constant in the program. Use the @zerograd macro for those. Note that kwargs are NOT unboxed.
AutoGrad.getval
— Function.getval(x)
Unbox x if it is a boxed value (Rec), otherwise return x.
Function Index
Knet.Adadelta
Knet.Adagrad
Knet.Adam
Knet.KnetArray
Knet.Momentum
Knet.Rmsprop
Knet.Sgd
AutoGrad.getval
AutoGrad.grad
AutoGrad.gradcheck
AutoGrad.gradloss
Knet.bilinear
Knet.conv4
Knet.deconv4
Knet.dir
Knet.gaussian
Knet.gpu
Knet.invx
Knet.logp
Knet.logsumexp
Knet.mat
Knet.pool
Knet.relu
Knet.sigm
Knet.unpool
Knet.update!
Knet.xavier
AutoGrad.@primitive
AutoGrad.@zerograd