Reference
AutoGrad
AutoGrad.grad
— Function.
Usage:
x = Param([1,2,3]) # user declares parameters
x => P([1,2,3]) # they are wrapped in a struct
value(x) => [1,2,3] # we can get the original value
sum(abs2,x) => 14 # they act like regular values outside of differentiation
y = @diff sum(abs2,x) # if you want the gradients
y => T(14) # you get another struct
value(y) => 14 # which represents the same value
grad(y,x) => [2,4,6] # but also contains gradients for all Params
Param(x) returns a struct that acts like x but marks it as a parameter you want to compute gradients with respect to.
@diff expr evaluates an expression and returns a struct that contains its value (which should be a scalar) and gradient information.
grad(y, x) returns the gradient of y (output by @diff) with respect to any parameter x::Param, or nothing if the gradient is 0.
value(x) returns the value associated with x if x is a Param or the output of @diff, otherwise returns x.
params(x) returns an array of Params found by a recursive search of object x.
Alternative usage:
x = [1 2 3]
f(x) = sum(abs2, x)
f(x) => 14
grad(f)(x) => [2 4 6]
gradloss(f)(x) => ([2 4 6], 14)
Given a scalar valued function f, grad(f,argnum=1) returns another function g which takes the same inputs as f and returns the gradient of the output with respect to the argnum'th argument. gradloss is similar except the resulting function also returns f's output.
AutoGrad.gradloss
— Function.
Shares its documentation with AutoGrad.grad above: gradloss(f)(x) returns both the gradient and f's output, e.g. gradloss(f)(x) => ([2 4 6], 14).
AutoGrad.gradcheck
— Function.
gradcheck(f, x...; kwargs...)
Numerically check the gradient of f(x...) and return a boolean result.
Each argument can be a Number, Array, Tuple or Dict which in turn can contain other Arrays etc. Only 10 random entries in each large numeric array are checked by default. If the output of f is not a number, we check the gradient of sum(f(x...)).
Keywords
args=:
: the argument indices to check gradients with respect to. Could be an array or range of indices or a single index. By default all arguments that have a length method are checked.
kw=()
: keyword arguments to be passed to f.
nsample=10
: number of random entries from each numeric array in gradient dw=(grad(f))(w,x...;o...) compared to their numerical estimates.
atol=rtol=0.01
: tolerance parameters. See isapprox for their meaning.
delta=0.0001
: step size for numerical gradient calculation.
verbose=1
: 0 prints nothing, 1 shows failing tests, 2 shows all tests.
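The check gradcheck performs can be illustrated with central finite differences. The sketch below is hypothetical: numgrad and checkgrad are illustrative names, not AutoGrad functions, and unlike gradcheck it compares every entry against a hand-written gradient instead of sampling nsample entries of an AutoGrad-computed one.

```julia
# Hypothetical sketch: compare an analytic gradient with central finite
# differences, in the spirit of gradcheck (not its implementation).

# Numerical gradient of a scalar function f at array x, step size delta.
function numgrad(f, x::AbstractArray{Float64}; delta=1e-4)
    g = similar(x)
    for i in eachindex(x)
        orig = x[i]
        x[i] = orig + delta; fplus = f(x)
        x[i] = orig - delta; fminus = f(x)
        x[i] = orig                          # restore the entry
        g[i] = (fplus - fminus) / (2delta)
    end
    return g
end

# Compare a user-supplied gradient function df with numgrad, using the
# same default tolerances as gradcheck (atol=rtol=0.01).
checkgrad(f, df, x; atol=0.01, rtol=0.01) =
    isapprox(df(x), numgrad(f, copy(x)); atol=atol, rtol=rtol)

sqnorm(x) = sum(abs2, x)     # the example function from the AutoGrad section
dsqnorm(x) = 2 .* x          # its analytic gradient

println(checkgrad(sqnorm, dsqnorm, [1.0, 2.0, 3.0]))   # true
```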
KnetArray
Knet.KnetArray
— Type.
KnetArray{T}(undef,dims)
KnetArray(a::AbstractArray)
Array(k::KnetArray)
Container for GPU arrays that supports most of the AbstractArray interface. The constructor allocates a KnetArray in the currently active device, as specified by gpu(). KnetArrays and Arrays can be converted to each other as shown above, which involves copying to and from the GPU memory. Only Float32/64 KnetArrays are fully supported.
Important differences from the alternative CudaArray are: (1) a custom memory manager that minimizes the number of calls to the slow cudaMalloc by reusing already allocated but garbage collected GPU pointers; (2) a custom getindex that handles ranges such as a[5:10] as views with shared memory instead of copies; (3) custom CUDA kernels that implement elementwise, broadcasting, and reduction operations.
Supported functions:
Indexing: getindex, setindex! with the following index types:
- 1-D: Real, Colon, OrdinalRange, AbstractArray{Real}, AbstractArray{Bool}, CartesianIndex, AbstractArray{CartesianIndex}, EmptyArray, KnetArray{Int32} (low level), KnetArray{0/1} (using float for BitArray) (1-D includes linear indexing of multidimensional arrays)
- 2-D: (Colon,Union{Real,Colon,OrdinalRange,AbstractVector{Real},AbstractVector{Bool},KnetVector{Int32}}), (Union{Real,AbstractUnitRange,Colon}...) (in any order)
- N-D: (Real...)
Array operations: ==, !=, cat, convert, copy, copyto!, deepcopy, display, eachindex, eltype, endof, fill!, first, hcat, isapprox, isempty, length, ndims, one, ones, pointer, rand!, randn!, reshape, similar, size, stride, strides, summary, vcat, vec, zero. (cat(x,y,dims=i) supported for i=1,2.)
Math operators: (-), abs, abs2, acos, acosh, asin, asinh, atan, atanh, cbrt, ceil, cos, cosh, cospi, erf, erfc, erfcinv, erfcx, erfinv, exp, exp10, exp2, expm1, floor, log, log10, log1p, log2, round, sign, sin, sinh, sinpi, sqrt, tan, tanh, trunc
Broadcasting operators: (.*), (.+), (.-), (./), (.<), (.<=), (.!=), (.==), (.>), (.>=), (.^), max, min. (Boolean operators generate outputs with same type as inputs; no support for KnetArray{Bool}.)
Reduction operators: countnz, maximum, mean, minimum, prod, sum, sumabs, sumabs2, norm.
Linear algebra: (*), axpy!, permutedims (up to 5D), transpose
Knet extras: relu, sigm, invx, logp, logsumexp, conv4, pool, deconv4, unpool, mat, update! (Only 4D/5D, Float32/64 KnetArrays support conv4, pool, deconv4, unpool)
Memory management
Knet models do not overwrite arrays which need to be preserved for gradient calculation. This leads to a lot of allocation, and regular GPU memory allocation is prohibitively slow. Fortunately, most models use identically sized arrays over and over again, so we can minimize the number of actual allocations by reusing preallocated but garbage collected pointers.
When Julia's gc reclaims a KnetArray, a special finalizer keeps its pointer in a table instead of releasing the memory. If an array with the same size in bytes is later requested, the same pointer is reused. The exact algorithm for allocation is:
1. Try to find a previously allocated and garbage collected pointer in the current device. (0.5 μs)
2. If not available, try to allocate a new array using cudaMalloc. (10 μs)
3. If not successful, try running gc() and see if we get a pointer of the right size. (75 ms, but this should be amortized over all reusable pointers that become available due to the gc)
4. Finally, if all else fails, clean up all saved pointers in the current device using cudaFree and try allocation one last time. (25-70 ms; however, this eliminates all reusable pointers)
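The reuse table behind steps 1 and 2 can be sketched on the CPU. This is a hypothetical illustration, not Knet's memory manager: integers stand in for GPU pointers, fake_malloc stands in for cudaMalloc, PointerPool, alloc! and release! are illustrative names, and the gc() and cudaFree fallbacks of steps 3 and 4 are left out.

```julia
# Hypothetical sketch of a size-keyed pointer reuse table (steps 1-2 above).
# Integers stand in for GPU pointers; fake_malloc stands in for cudaMalloc.
mutable struct PointerPool
    free::Dict{Int,Vector{Int}}   # size in bytes => garbage collected pointers
    nextptr::Int                  # next fresh "address"
    mallocs::Int                  # number of real allocations performed
end
PointerPool() = PointerPool(Dict{Int,Vector{Int}}(), 1, 0)

function fake_malloc(pool::PointerPool, nbytes::Int)
    pool.mallocs += 1
    p = pool.nextptr
    pool.nextptr += nbytes
    return p
end

# Step 1: reuse a saved pointer of exactly this size; step 2: allocate.
function alloc!(pool::PointerPool, nbytes::Int)
    ptrs = get(pool.free, nbytes, nothing)
    if ptrs !== nothing && !isempty(ptrs)
        return pop!(ptrs)
    end
    return fake_malloc(pool, nbytes)
end

# What the finalizer does: keep the pointer in the table instead of freeing it.
release!(pool::PointerPool, p::Int, nbytes::Int) =
    push!(get!(Vector{Int}, pool.free, nbytes), p)

pool = PointerPool()
a = alloc!(pool, 400)           # fresh allocation
release!(pool, a, 400)          # array is "garbage collected"
b = alloc!(pool, 400)           # same size in bytes: pointer is reused
println((a == b, pool.mallocs)) # (true, 1)
```

Keying the table on the exact byte size keeps lookup trivial and matches the observation that models request identically sized arrays repeatedly.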
Utilities
Knet.accuracy
— Function.
accuracy(scores, answers; dims=1, average=true)
Given an unnormalized scores matrix and an Integer array of correct answers, return the ratio of instances where the correct answer has the maximum score. dims=1 means instances are in columns, dims=2 means instances are in rows. Use average=false to return the number of correct answers instead of the ratio.
accuracy(model, data; dims=1, average=true, o...)
Compute accuracy(model(x; o...), y; dims) for (x,y) in data and return the ratio (if average=true) or the count (if average=false) of correct answers.
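For a scores matrix, the computation described above amounts to an argmax per instance followed by a comparison with the answers. A minimal sketch, assuming a plain Julia matrix; accuracy_sketch is a hypothetical name, not Knet's implementation.

```julia
# Hypothetical sketch of accuracy(scores, answers; dims=1, average=true)
# for a plain matrix of scores; not Knet's implementation.
function accuracy_sketch(scores::AbstractMatrix, answers::AbstractVector{<:Integer};
                         dims=1, average=true)
    preds = if dims == 1
        [argmax(view(scores, :, j)) for j in axes(scores, 2)]   # best class per column
    else
        [argmax(view(scores, i, :)) for i in axes(scores, 1)]   # best class per row
    end
    ncorrect = count(preds .== answers)
    return average ? ncorrect / length(answers) : ncorrect
end

scores = [0.1 2.0 0.3;
          1.5 0.2 0.1;
          0.4 0.9 2.2]        # 3 classes x 3 instances (dims=1)
answers = [2, 1, 1]
println(accuracy_sketch(scores, answers; average=false))   # 2
```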
Knet.dir
— Function.
Knet.dir(path...)
Construct a path relative to Knet root.
Example
julia> Knet.dir("examples","mnist.jl")
"/home/dyuret/.julia/v0.5/Knet/examples/mnist.jl"
Knet.dropout
— Function.
dropout(x, p)
Given an array x and probability 0<=p<=1, just return x if p==0, or return an array y in which each element is 0 with probability p or x[i]/(1-p) with probability 1-p. Use seed::Number to set the random number seed for reproducible results. See (Srivastava et al. 2014) for a reference.
Knet.gpu
— Function.
gpu() returns the id of the active GPU device or -1 if none are active.
gpu(true) resets all GPU devices and activates the one with the most available memory.
gpu(false) resets and deactivates all GPU devices.
gpu(d::Int) activates the GPU device d if 0 <= d < gpuCount(), otherwise deactivates devices.
gpu(true/false) resets all devices. If there are any allocated KnetArrays, their pointers will be left dangling. Thus gpu(true/false) should only be used during startup. If you want to suspend GPU use temporarily, use gpu(-1).
gpu(d::Int) does not reset the devices. You can select a previous device and find allocated memory preserved. However, trying to operate on arrays of an inactive device will result in an error.
Knet.invx
— Function.
invx(x) = (1 ./ x)
Knet.gc
— Function.
Knet.gc(dev=gpu())
cudaFree all pointers allocated on device dev that were previously allocated and garbage collected. Normally Knet holds on to all garbage collected pointers for reuse. Try this if you run out of GPU memory.
Knet.softmax
— Function.
softmax(x; dims=1, algo=1)
The softmax function typically used in classification. Gives the same results as exp.(logp(x; dims=dims)).
If algo=1, computation is more accurate; if algo=0, it is faster.
See also logsoftmax.
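A common way to compute softmax without overflow is to subtract the maximum along dims before exponentiating, which leaves the result unchanged mathematically. The sketch shows that technique for plain arrays; softmax_sketch is a hypothetical helper, and it is not meant to describe what Knet's algo=0/1 kernels do internally.

```julia
# Hypothetical sketch of a numerically stable softmax along `dims`:
# subtracting the maximum first means the largest exponent is exp(0) = 1.
function softmax_sketch(x::AbstractArray; dims=1)
    z = exp.(x .- maximum(x; dims=dims))
    return z ./ sum(z; dims=dims)
end

x = [1.0 1000.0;
     2.0 1001.0]
p = softmax_sketch(x)        # no overflow even for scores around 1000
println(p[:, 1] == p[:, 2])  # true: softmax is invariant to shifting a column
```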
Knet.logp
— Function.
logp(x; dims=:)
Treat entries in x as unnormalized log probabilities and return normalized log probabilities.
dims is an optional argument; if not specified, the normalization is over the whole x, otherwise the normalization is performed over the given dimensions. In particular, if x is a matrix, dims=1 normalizes columns of x and dims=2 normalizes rows of x.
Knet.logsoftmax
— Function.
logsoftmax(x; dims=:)
Equivalent to logp(x; dims=:). See also softmax.
Knet.logsumexp
— Function.
logsumexp(x; dims=:)
Compute log(sum(exp.(x); dims=dims)) in a numerically stable manner.
dims is an optional argument; if not specified, the summation is over the whole x, otherwise the summation is performed over the given dimensions. In particular, if x is a matrix, dims=1 sums columns of x and dims=2 sums rows of x.
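The standard stable formulation factors out the maximum m: log(sum(exp.(x))) = m + log(sum(exp.(x .- m))). A sketch for plain arrays under that formulation, with logsumexp_sketch as a hypothetical name. The same quantity gives normalized log probabilities as x .- logsumexp_sketch(x; dims=dims), which matches the definition of logp above.

```julia
# Hypothetical sketch of a numerically stable logsumexp(x; dims):
# log(sum(exp.(x))) = m + log(sum(exp.(x .- m))) with m = maximum(x).
function logsumexp_sketch(x::AbstractArray; dims=:)
    m = maximum(x; dims=dims)
    return m .+ log.(sum(exp.(x .- m); dims=dims))
end

x = [1000.0, 1000.0]
println(log(sum(exp.(x))))    # Inf: exp(1000) overflows
println(logsumexp_sketch(x))  # finite, equal to 1000 + log(2)
```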
Knet.logistic
— Function.
logistic(scores, answers; average=true)
Computes the logistic loss given scores (predicted values) and answer labels. Answer values should be in {-1,1}; it returns the mean (if average=true) or the sum (if average=false) of log.(1 .+ exp.(-answers .* scores)). See also bce.
Knet.bce
— Function.
bce(scores, answers; average=true)
Computes the binary cross entropy given scores (predicted values) and answer labels. Answer values should be in {0,1}; it returns the negative of the mean (if average=true) or the sum (if average=false) of answers*log(p) + (1-answers)*log(1-p), where p is equal to 1/(1 + exp.(-scores)). See also logistic.
Knet.minibatch
— Function.
minibatch(x, [y], batchsize; shuffle, partial, xtype, ytype, xsize, ysize)
Return an iterable of minibatches [(xi,yi)...] given data tensors x, y and batchsize. The last dimension of x and y should match and give the number of instances. y is optional. Keyword arguments:
shuffle=false
: Shuffle the instances before minibatching.
partial=false
: If true, include the last partial minibatch with fewer than batchsize instances.
xtype=typeof(x)
: Convert xi in minibatches to this type.
ytype=typeof(y)
: Convert yi in minibatches to this type.
xsize=size(x)
: Convert xi in minibatches to this shape.
ysize=size(y)
: Convert yi in minibatches to this shape.
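The batching logic (instances along the last dimension, optional shuffling, dropping the last partial batch by default) can be sketched as follows. minibatch_sketch is a hypothetical name; it returns a Vector of (xi, yi) views and ignores xtype, ytype, xsize and ysize.

```julia
using Random

# Hypothetical sketch of minibatch(x, y, batchsize; shuffle, partial):
# instances live in the last dimension of x and y.
function minibatch_sketch(x, y, batchsize; shuffle=false, partial=false)
    n = size(x, ndims(x))
    @assert size(y, ndims(y)) == n "last dimensions of x and y must match"
    idx = shuffle ? randperm(n) : collect(1:n)
    stop = partial ? n : n - n % batchsize   # drop the last partial batch by default
    batches = []
    for i in 1:batchsize:stop
        j = idx[i:min(i + batchsize - 1, stop)]
        push!(batches, (selectdim(x, ndims(x), j), selectdim(y, ndims(y), j)))
    end
    return batches
end

x = rand(3, 10)        # 3 features, 10 instances
y = collect(1:10)      # one label per instance
println(length(minibatch_sketch(x, y, 4)))                 # 2
println(length(minibatch_sketch(x, y, 4; partial=true)))   # 3
```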
Knet.nll
— Function.
nll(scores, answers; dims=1, average=true)
Given an unnormalized scores matrix and an Integer array of correct answers, return the average per-instance negative log likelihood. dims=1 means instances are in columns, dims=2 means instances are in rows. Use average=false to return the sum instead of the per-instance average.
nll(model, data; dims=1, average=true, o...)
Compute nll(model(x; o...), y; dims) for (x,y) in data and return the per-instance average (if average=true) or total (if average=false) negative log likelihood.
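In terms of the earlier functions, nll normalizes the scores into log probabilities and picks out the entry of the correct answer for each instance. A minimal sketch covering only instances in columns (dims=1); nll_sketch is a hypothetical name, not Knet's implementation.

```julia
using Statistics: mean

# Hypothetical sketch of nll(scores, answers; dims=1, average=true), instances
# in columns: normalize each column to log probabilities with a stable
# log-sum-exp, then take minus the log probability of the correct answer.
function nll_sketch(scores::AbstractMatrix, answers::AbstractVector{<:Integer}; average=true)
    m = maximum(scores; dims=1)
    logprobs = scores .- (m .+ log.(sum(exp.(scores .- m); dims=1)))
    l = [-logprobs[answers[j], j] for j in axes(scores, 2)]
    return average ? mean(l) : sum(l)
end

scores = zeros(4, 2)     # uniform scores over 4 classes, 2 instances
answers = [1, 3]
println(nll_sketch(scores, answers))   # log(4): every class has probability 1/4
```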
Knet.relu
— Function.
relu(x)
Return max(0,x).
Reference: Rectified Linear Units Improve Restricted Boltzmann Machines (https://icml.cc/Conferences/2010/abstracts.html#432).
Knet.seed!
— Function.
Knet.seed!(n::Integer)
Run seed!(n) on both the CPU and the GPU.
Knet.sigm
— Function.
sigm(x) = (1 ./ (1 .+ exp.(-x)))
Convolution and Pooling
Knet.conv4
— Function.
conv4(w, x; kwargs...)
Execute convolutions or cross-correlations using filters specified with w over tensor x.
Currently KnetArray{Float32/64,4/5} and Array{Float32/64,4} are supported as w and x. If w has dimensions (W1,W2,...,I,O) and x has dimensions (X1,X2,...,I,N), the result y will have dimensions (Y1,Y2,...,O,N) where
Yi = 1 + floor((Xi + 2*padding[i] - Wi) / stride[i])
Here I is the number of input channels, O is the number of output channels, N is the number of instances, and Wi,Xi,Yi are spatial dimensions. padding and stride are keyword arguments that can be specified as a single number (in which case they apply to all dimensions), or an array/tuple with entries for each spatial dimension.
Keywords
padding=0
: the number of extra zeros implicitly concatenated at the start and at the end of each dimension.
stride=1
: the number of elements to slide to reach the next filtering window.
upscale=1
: upscale factor for each dimension.
mode=0
: 0 for convolution and 1 for cross-correlation.
alpha=1
: can be used to scale the result.
handle
: handle to a previously created cuDNN context. Defaults to a Knet allocated handle.
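The output-size formula can be checked with a few lines of integer arithmetic. convout is a hypothetical helper for one spatial dimension, and the layer sizes in the example are arbitrary; for positive arguments fld is floor division, matching the floor in the formula.

```julia
# One spatial dimension of the conv4 output size (hypothetical helper):
# Yi = 1 + floor((Xi + 2*padding[i] - Wi) / stride[i]).
convout(X, W; padding=0, stride=1) = 1 + fld(X + 2padding - W, stride)

println(convout(28, 5))                        # 24
println(convout(28, 5; padding=2))             # 28
println(convout(32, 3; padding=1, stride=2))   # 16

# w: (W1,W2,I,O) = (5,5,1,20), x: (X1,X2,I,N) = (28,28,1,100)
wsize, xsize = (5, 5, 1, 20), (28, 28, 1, 100)
ysize = (convout(xsize[1], wsize[1]), convout(xsize[2], wsize[2]), wsize[4], xsize[4])
println(ysize)                                 # (24, 24, 20, 100)
```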
Knet.deconv4
— Function.
y = deconv4(w, x; kwargs...)
Simulate 4-D deconvolution by using a transposed convolution operation. Its forward pass is equivalent to the backward pass of a convolution (gradients with respect to the input tensor). Likewise, its backward pass (gradients with respect to the input tensor) is equivalent to the forward pass of a convolution. Since it swaps the forward and backward passes of the convolution operation, the padding and stride options belong to the output tensor. See this report for further explanation.
Currently KnetArray{Float32/64,4} and Array{Float32/64,4} are supported as w and x. If w has dimensions (W1,W2,...,O,I) and x has dimensions (X1,X2,...,I,N), the result y will have dimensions (Y1,Y2,...,O,N) where
Yi = Wi + stride[i]*(Xi-1) - 2*padding[i]
Here I is the number of input channels, O is the number of output channels, N is the number of instances, and Wi,Xi,Yi are spatial dimensions. padding and stride are keyword arguments that can be specified as a single number (in which case they apply to all dimensions), or an array/tuple with entries for each spatial dimension.
Keywords
padding=0
: the number of extra zeros implicitly concatenated at the start and at the end of each dimension.
stride=1
: the number of elements to slide to reach the next filtering window.
mode=0
: 0 for convolution and 1 for cross-correlation.
alpha=1
: can be used to scale the result.
handle
: handle to a previously created cuDNN context. Defaults to a Knet allocated handle.
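Because padding and stride belong to the output tensor, running the conv4 size formula on a deconv4 output with the same settings recovers the input size in these examples. deconvout and convout are hypothetical helpers for one spatial dimension.

```julia
# One spatial dimension of the deconv4 output size (hypothetical helper):
# Yi = Wi + stride[i]*(Xi - 1) - 2*padding[i].
deconvout(X, W; padding=0, stride=1) = W + stride * (X - 1) - 2padding
# The conv4 size formula, to show the shapes round-trip here.
convout(X, W; padding=0, stride=1) = 1 + fld(X + 2padding - W, stride)

Y = deconvout(16, 4; padding=1, stride=2)
println(Y)                                    # 32: upsampling by the stride
println(convout(Y, 4; padding=1, stride=2))   # 16: back to the input size
```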
Knet.mat
— Function.
mat(x)
Reshape x into a two-dimensional matrix.
This is typically used when turning the output of a 4-D convolution result into a 2-D input for a fully connected layer. For 1-D inputs, returns reshape(x, (length(x),1)). For inputs with more than two dimensions of size (X1,X2,...,XD), returns
reshape(x, (X1*X2*...*X[D-1],XD))
Knet.pool
— Function.
pool(x; kwargs...)
Compute pooling of input values (i.e., the maximum or average of several adjacent values) to produce an output with smaller height and/or width.
Currently 4 or 5 dimensional KnetArrays with Float32 or Float64 entries are supported. If x has dimensions (X1,X2,...,I,N), the result y will have dimensions (Y1,Y2,...,I,N) where
Yi = 1 + floor((Xi + 2*padding[i] - window[i]) / stride[i])
Here I is the number of input channels, N is the number of instances, and Xi,Yi are spatial dimensions. window, padding and stride are keyword arguments that can be specified as a single number (in which case they apply to all dimensions), or an array/tuple with entries for each spatial dimension.
Keywords:
window=2
: the pooling window size for each dimension.
padding=0
: the number of extra zeros implicitly concatenated at the start and at the end of each dimension.
stride=window
: the number of elements to slide to reach the next pooling window.
mode=0
: 0 for max, 1 for average including padded values, 2 for average excluding padded values.
maxpoolingNanOpt=0
: NaN numbers are not propagated if 0, they are propagated if 1.
alpha=1
: can be used to scale the result.
handle
: Handle to a previously created cuDNN context. Defaults to a Knet allocated handle.
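For a single 2-D slice without padding, max pooling (mode=0) with the default stride=window reduces each non-overlapping window to its maximum. poolout and maxpool2d are hypothetical helpers illustrating the size formula and the operation; they are not Knet's cuDNN-backed pool.

```julia
# One spatial dimension of the pool output size (hypothetical helper);
# stride defaults to window, as in the keyword list.
poolout(X; window=2, padding=0, stride=window) = 1 + fld(X + 2padding - window, stride)

# Hypothetical sketch: max pooling of a single 2-D slice with no padding.
function maxpool2d(x::AbstractMatrix; window=2, stride=window)
    Y1 = poolout(size(x, 1); window=window, stride=stride)
    Y2 = poolout(size(x, 2); window=window, stride=stride)
    return [maximum(x[(i-1)*stride .+ (1:window), (j-1)*stride .+ (1:window)]) for i in 1:Y1, j in 1:Y2]
end

x = reshape(collect(1.0:16.0), 4, 4)   # a 4x4 slice, filled column by column
println(maxpool2d(x))                  # [6.0 14.0; 8.0 16.0]
println(poolout(28))                   # 14
```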
Knet.unpool
— Function.
Unpooling; reverse of pooling.
x == pool(unpool(x;o...); o...)
Recurrent neural networks
Knet.rnninit
Knet.rnnforw
Knet.rnnparam
Knet.rnnparams
Batch Normalization
Knet.bnmoments
— Function.
bnmoments(;momentum=0.1, mean=nothing, var=nothing, meaninit=zeros, varinit=ones) can be used to directly load moments from data. meaninit and varinit are called if mean and var are nothing. The type and size of mean and var are determined automatically from the inputs in the batchnorm calls. A BNMoments object is returned.
BNMoments
A high-level data structure used to store the running mean and running variance of batch normalization, with the following fields:
momentum::AbstractFloat
: A real number between 0 and 1 used as the weight of the latest batch mean and variance. The existing running mean or variance is multiplied by (1-momentum).
mean
: The running mean.
var
: The running variance.
meaninit
: The function used to initialize the running mean. Should either be nothing or of the form (eltype, dims...)->data. zeros is a good option.
varinit
: The function used to initialize the running variance. Should either be nothing or of the form (eltype, dims...)->data. ones is a good option.
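The momentum field implies an exponential moving average: running = momentum*batch + (1-momentum)*running. The sketch below applies that update to plain vectors, with lazy initialization that mirrors meaninit=zeros and varinit=ones. RunningMoments and update_moments! are hypothetical names, not the BNMoments internals.

```julia
# Hypothetical sketch of the running-moment update described by the
# momentum field: running = momentum*batch + (1-momentum)*running.
mutable struct RunningMoments
    momentum::Float64
    mean::Union{Nothing,Vector{Float64}}
    var::Union{Nothing,Vector{Float64}}
end
RunningMoments(; momentum=0.1) = RunningMoments(momentum, nothing, nothing)

function update_moments!(m::RunningMoments, batchmean, batchvar)
    if m.mean === nothing        # lazy init, like meaninit=zeros and varinit=ones
        m.mean = zeros(length(batchmean))
        m.var = ones(length(batchvar))
    end
    m.mean .= m.momentum .* batchmean .+ (1 - m.momentum) .* m.mean
    m.var .= m.momentum .* batchvar .+ (1 - m.momentum) .* m.var
    return m
end

m = RunningMoments()
update_moments!(m, [1.0, 2.0], [4.0, 4.0])
println(m.mean)   # [0.1, 0.2]
println(m.var)    # [1.3, 1.3]
```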
Knet.bnparams
— Function.
bnparams(etype, channels) creates a single 1d array that contains both scale and bias of batchnorm, where the first half is scale and the second half is bias.
bnparams(channels) calls bnparams with etype=Float64, following Julia convention.
Knet.batchnorm
— Function.
batchnorm(x[, moments, params]; kwargs...) performs batch normalization on x with optional scaling factor and bias stored in params.
2d, 4d and 5d inputs are supported. Mean and variance are computed over dimensions (2,), (1,2,4) and (1,2,3,5) for 2d, 4d and 5d arrays, respectively.
moments stores the running mean and variance to be used in testing. It is optional in training mode, but mandatory in test mode. Training and test modes are controlled by the training keyword argument.
params stores the optional affine parameters gamma and beta. The bnparams function can be used to initialize params.
Example
# Initialization, C is an integer
moments = bnmoments()
params = bnparams(C)
...
# size(x) -> (H, W, C, N)
y = batchnorm(x, moments, params)
# size(y) -> (H, W, C, N)
Keywords
eps=1e-5
: The epsilon parameter added to the variance to avoid division by 0.
training
: When training is true, the mean and variance of x are used and the moments argument is modified if it is provided. When training is false, the mean and variance stored in the moments argument are used. The default value is true when at least one of x and params is AutoGrad.Value, false otherwise.
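For a 2d input of size (C, N), training-mode batch normalization computes statistics over dimension 2, normalizes with eps, then applies the scale and bias packed as in bnparams (first half scale, second half bias). batchnorm2d_sketch is a hypothetical helper that ignores moments, and using the biased batch variance is an assumption, not something stated above.

```julia
using Statistics

# Hypothetical sketch of training-mode batch normalization for a 2-D input
# of size (C, N): mean and variance over dimension 2, then scale and shift.
# params packs scale and bias like bnparams: first half scale, second half bias.
function batchnorm2d_sketch(x::AbstractMatrix, params::AbstractVector; eps=1e-5)
    C = size(x, 1)
    gamma, beta = params[1:C], params[C+1:2C]
    mu = mean(x; dims=2)
    sigma2 = var(x; dims=2, corrected=false)   # biased batch variance (an assumption)
    xhat = (x .- mu) ./ sqrt.(sigma2 .+ eps)
    return gamma .* xhat .+ beta
end

x = [1.0 2.0 3.0 4.0;
     10.0 20.0 30.0 40.0]          # C=2 features, N=4 instances
params = vcat(ones(2), zeros(2))   # scale 1, bias 0
y = batchnorm2d_sketch(x, params)  # each row now has mean ~0 and variance ~1
```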
Optimization methods
Knet.update!
Knet.optimizers
Knet.Adadelta
Knet.Adagrad
Knet.Adam
Knet.Momentum
Knet.Nesterov
Knet.Rmsprop
Knet.Sgd
Hyperparameter optimization
Knet.goldensection
— Function.
goldensection(f,n;kwargs) => (fmin,xmin)
Find the minimum of f using concurrent golden section search in n dimensions. See Knet.goldensection_demo() for an example.
f is a function from a Vector{Float64} of length n to a Number. It can return NaN for out of range inputs. Goldensection will always start with a zero vector as the initial input to f, and the initial step size will be 1 in each dimension. The user should define f to scale and shift this input range into a vector meaningful for their application. For positive inputs like learning rate or hidden size, you can use a transformation such as x0*exp(x) where x is a value goldensection passes to f and x0 is your initial guess for this value. This will effectively start the search at x0, then move with multiplicative steps.
I designed this algorithm combining ideas from Golden Section Search and Hill Climbing Search. It essentially runs golden section search concurrently in each dimension, picking the next step based on estimated gain.
Keyword arguments
dxmin=0.1
: smallest step size.
accel=φ
: acceleration rate. Golden ratio φ=1.618... is best.
verbose=false
: use true to print individual steps.
history=[]
: cache of [(x,f(x)),...] function evaluations.
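The building block that goldensection runs concurrently in each dimension is classic golden section search. The sketch below is the textbook single-dimension version on a bracketing interval [a, b] for a unimodal function; golden1d is a hypothetical name, and it does not implement Knet's zero-vector start, step acceleration, or gain-based choice of dimension.

```julia
# Classic 1-D golden section search for the minimum of a unimodal f on [a, b].
# Textbook building block only, not Knet's concurrent n-dimensional variant.
function golden1d(f, a, b; tol=1e-6)
    invphi = (sqrt(5) - 1) / 2          # 1/φ ≈ 0.618
    c = b - invphi * (b - a)
    d = a + invphi * (b - a)
    fc, fd = f(c), f(d)
    while b - a > tol
        if fc < fd                      # minimum lies in [a, d]
            b, d, fd = d, c, fc
            c = b - invphi * (b - a)
            fc = f(c)
        else                            # minimum lies in [c, b]
            a, c, fc = c, d, fd
            d = a + invphi * (b - a)
            fd = f(d)
        end
    end
    x = (a + b) / 2
    return f(x), x                      # (fmin, xmin), the same order as goldensection
end

fmin, xmin = golden1d(x -> (x - 2)^2 + 1, 0.0, 5.0)
println(round(xmin; digits=4))   # 2.0
```

Each iteration reuses one of the two interior evaluations, which is why the golden ratio is used: the interval shrinks by 1/φ per step at the cost of a single new call to f.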
Knet.hyperband
— Function.
hyperband(getconfig, getloss, maxresource=27, reduction=3)
Hyperparameter optimization using the hyperband algorithm from (Li et al. 2016). You can try a simple MNIST example using Knet.hyperband_demo().
Arguments
getconfig() returns random configurations with a user defined type and distribution.
getloss(c,n) returns loss for configuration c and number of resources (e.g. epochs) n.
maxresource is the maximum number of resources any one configuration should be given.
reduction is an algorithm parameter (see paper); 3 is a good value.
Initialization
Knet.bilinear
— Function.
Bilinear interpolation filter weights; used for initializing deconvolution layers.
Adapted from https://github.com/shelhamer/fcn.berkeleyvision.org/blob/master/surgery.py#L33
Arguments:
T
: Data Type
fw
: Width upscale factor
fh
: Height upscale factor
IN
: Number of input filters
ON
: Number of output filters
Example usage:
w = bilinear(Float32,2,2,128,128)
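The core of a bilinear filter is a separable 2-D kernel whose weights fall off linearly from the center. The sketch follows the upsample_filt idea in the surgery.py file linked above, not Knet's code: bilinear_kernel is a hypothetical name, it builds a single 2-D kernel rather than the full (fw, fh, IN, ON) array, and the kernel size 2f - f % 2 for upscale factor f is an assumption borrowed from common FCN usage.

```julia
# Hypothetical sketch of a 2-D bilinear upsampling kernel following the
# upsample_filt idea in the FCN surgery.py linked above (not Knet's code).
function bilinear_kernel(T::Type{<:AbstractFloat}, ksize::Int)
    factor = (ksize + 1) ÷ 2
    center = isodd(ksize) ? factor - 1 : factor - 0.5
    # 1-D weight profile over 0-based offsets, matching the Python ogrid indexing
    v = [1 - abs(i - center) / factor for i in 0:ksize-1]
    return T.(v * v')            # separable: outer product of the 1-D profile
end

up = 2                                     # upscale factor (assumed kernel size 2f - f % 2)
k = bilinear_kernel(Float32, 2up - up % 2) # 4x4 Float32 kernel
# each row and column profile is proportional to [0.25, 0.75, 0.75, 0.25]
```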
Knet.gaussian
— Function.
gaussian(a...; mean=0.0, std=0.01)
Return a Gaussian array with a given mean and standard deviation. The a arguments are passed to randn.
Knet.xavier
— Function.
xavier(a...)
Xavier initialization. The a arguments are passed to rand. See (Glorot and Bengio 2010) for a description. Caffe implements this slightly differently. Lasagne calls it GlorotUniform.
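The uniform variant from Glorot and Bengio (2010) draws weights from Uniform(-s, s) with s = sqrt(6/(fanin + fanout)), which gives variance 2/(fanin + fanout). The sketch applies this to a matrix; xavier_sketch is a hypothetical name, and treating the matrix as (fanout, fanin) is an assumption about layout rather than a description of how Knet's xavier computes fans for higher-dimensional arrays.

```julia
using Random, Statistics

# Hypothetical sketch of Xavier/Glorot uniform initialization for a weight
# matrix of size (fanout, fanin): entries ~ Uniform(-s, s) with
# s = sqrt(6 / (fanin + fanout)), following Glorot and Bengio (2010).
function xavier_sketch(fanout::Int, fanin::Int; rng=Random.default_rng())
    s = sqrt(6 / (fanin + fanout))
    return 2s .* rand(rng, fanout, fanin) .- s   # map [0, 1) onto [-s, s)
end

w = xavier_sketch(100, 300; rng=MersenneTwister(0))
println(var(w))   # close to 2 / (100 + 300) = 0.005
```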
AutoGrad (advanced)
AutoGrad.getval
AutoGrad.@primitive
AutoGrad.@zerograd
Function Index
Knet.Adadelta
Knet.Adagrad
Knet.Adam
Knet.KnetArray
Knet.Momentum
Knet.Nesterov
Knet.Rmsprop
AutoGrad.grad
AutoGrad.gradcheck
AutoGrad.gradloss
Knet.accuracy
Knet.batchnorm
Knet.bce
Knet.bilinear
Knet.bnmoments
Knet.bnparams
Knet.conv4
Knet.deconv4
Knet.dir
Knet.dropout
Knet.gaussian
Knet.gc
Knet.goldensection
Knet.gpu
Knet.hyperband
Knet.invx
Knet.logistic
Knet.logp
Knet.logsoftmax
Knet.logsumexp
Knet.mat
Knet.minibatch
Knet.nll
Knet.optimizers
Knet.pool
Knet.relu
Knet.rnnparam
Knet.rnnparams
Knet.seed!
Knet.sigm
Knet.softmax
Knet.unpool
Knet.update!
Knet.xavier
AutoGrad.@primitive
AutoGrad.@zerograd