Source: Internet  Editor: 程序博客网  Time: 2024/05/09 20:45

MXNet - Python API

>>> import mxnet as mx



This document lists the routines of the n-dimensional array package:

Package        Description
mxnet.ndarray  NDArray API of MXNet.

An NDArray is a multidimensional container of items of the same type and size. Various methods for data manipulation and computation are provided.

>>> x = mx.nd.array([[1, 2, 3], [4, 5, 6]])
>>> type(x)
<class 'mxnet.ndarray.NDArray'>
>>> x.shape
(2, 3)
>>> y = x + mx.nd.ones(x.shape)*3
>>> print(y.asnumpy())
[[ 4.  5.  6.]
 [ 7.  8.  9.]]
>>> z = y.as_in_context(mx.gpu(0))
>>> print(z)
<NDArray 2x3 @gpu(0)>

A detailed tutorial is available at


mxnet.ndarray is similar to numpy.ndarray in some respects, but the differences are not negligible. For example:

  • NDArray.T performs a real data transpose and returns a new, copied array instead of a view of the input array.
  • The dot operation contracts the last axis of the first input array with the first axis of the second input, whereas NumPy's dot uses the second-to-last axis of the second input array.
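The difference between the two dot conventions only shows up for inputs with more than two dimensions. The following is a minimal pure-Python sketch that computes output shapes only; the helper names `mxnet_dot_shape` and `numpy_dot_shape` are hypothetical and not part of either library.

```python
def mxnet_dot_shape(a_shape, b_shape):
    """Output shape when the LAST axis of a is contracted with the FIRST axis of b."""
    if a_shape[-1] != b_shape[0]:
        raise ValueError("last axis of a must match first axis of b")
    return tuple(a_shape[:-1]) + tuple(b_shape[1:])


def numpy_dot_shape(a_shape, b_shape):
    """Output shape when the LAST axis of a is contracted with the SECOND-TO-LAST axis of b.

    This sketch only covers the case where b has at least two dimensions.
    """
    if len(b_shape) < 2:
        raise ValueError("this sketch expects b to have at least 2 dimensions")
    if a_shape[-1] != b_shape[-2]:
        raise ValueError("last axis of a must match second-to-last axis of b")
    return tuple(a_shape[:-1]) + tuple(b_shape[:-2]) + (b_shape[-1],)


# For 2-D inputs the two conventions agree.
print(mxnet_dot_shape((2, 3), (3, 4)))        # (2, 4)
print(numpy_dot_shape((2, 3), (3, 4)))        # (2, 4)

# For higher-rank inputs the second operand must be laid out differently.
print(mxnet_dot_shape((2, 3, 4), (4, 5, 6)))  # (2, 3, 5, 6)
print(numpy_dot_shape((2, 3, 4), (5, 4, 6)))  # (2, 3, 5, 6)
```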

In addition, NDArray supports GPU computation and various neural network layers.


ndarray provides almost the same routines as symbol, and most routines in the two packages share the same C++ operator source code. However, ndarray differs from symbol in several aspects:

  • ndarray adopts imperative programming: statements are executed step by step, so results are available immediately.
  • Most binary operators, such as + and >, broadcast by default.
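To make the broadcasting behaviour of binary operators concrete, here is a small pure-Python sketch of how a result shape can be derived. It assumes the NumPy-style rule (axes are compatible when they are equal or one of them is 1); the helper `broadcast_shape` is hypothetical, and MXNet's operators may impose additional constraints.

```python
def broadcast_shape(lhs, rhs):
    """Result shape of a broadcasting binary operator (NumPy-style rule, assumed)."""
    # Left-pad the shorter shape with 1s so that both shapes have the same rank.
    rank = max(len(lhs), len(rhs))
    lhs = (1,) * (rank - len(lhs)) + tuple(lhs)
    rhs = (1,) * (rank - len(rhs)) + tuple(rhs)

    out = []
    for l, r in zip(lhs, rhs):
        if l == r or r == 1:
            out.append(l)
        elif l == 1:
            out.append(r)
        else:
            raise ValueError("shapes %s and %s cannot be broadcast" % (lhs, rhs))
    return tuple(out)


print(broadcast_shape((2, 3), (1, 3)))  # (2, 3)
print(broadcast_shape((2, 1), (1, 3)))  # (2, 3)
print(broadcast_shape((2, 3), (3,)))    # (2, 3)
```

Incompatible shapes such as (2, 3) and (4, 3) raise a ValueError in this sketch.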

In the rest of this document, we first give an overview of the methods provided by the ndarray.NDArray class, and then list the other routines provided by the ndarray package.

The NDArray class

Array attributes

Attribute        Description
NDArray.shape    Tuple of array dimensions.
NDArray.size     Number of elements in the array.
NDArray.context  Device context of the array.
NDArray.dtype    Data type of the array’s elements.

Array conversion

Conversion             Description
NDArray.copy           Makes a copy of this NDArray, keeping the same context.
NDArray.copyto         Copies the value of this array to another array.
NDArray.as_in_context  Returns an array on the target device with the same value as this array.
NDArray.asnumpy        Returns a numpy.ndarray object with value copied from this array.
NDArray.asscalar       Returns a scalar whose value is copied from this array.
NDArray.astype         Returns a copy of the array after casting to a specified type.

Array change shape

Shape change          Description
NDArray.T             Returns a copy of the array with axes transposed.
NDArray.reshape       Returns a view of this array with a new shape without altering any data.
NDArray.broadcast_to  Broadcasts the input array to a new shape.

Arithmetic operations

Arithmetic operation  Description
NDArray.__add__       x.__add__(y) <=> x+y <=> mx.nd.add(x, y)
NDArray.__sub__       x.__sub__(y) <=> x-y <=> mx.nd.subtract(x, y)
NDArray.__rsub__      x.__rsub__(y) <=> y-x <=> mx.nd.subtract(y, x)
NDArray.__neg__       x.__neg__() <=> -x
NDArray.__mul__       x.__mul__(y) <=> x*y <=> mx.nd.multiply(x, y)
NDArray.__div__       x.__div__(y) <=> x/y <=> mx.nd.divide(x, y)
NDArray.__rdiv__      x.__rdiv__(y) <=> y/x <=> mx.nd.divide(y, x)
NDArray.__pow__       x.__pow__(y) <=> x**y <=> mx.nd.power(x, y)

In-place arithmetic operations

In-place operation  Description
NDArray.__iadd__    x.__iadd__(y) <=> x+=y
NDArray.__isub__    x.__isub__(y) <=> x-=y
NDArray.__imul__    x.__imul__(y) <=> x*=y
NDArray.__idiv__    x.__idiv__(y) <=> x/=y

Comparison operators

Comparison operator  Description
NDArray.__lt__       x.__lt__(y) <=> x<y <=> mx.nd.lesser(x, y)
NDArray.__le__       x.__le__(y) <=> x<=y <=> mx.nd.lesser_equal(x, y)
NDArray.__gt__       x.__gt__(y) <=> x>y <=> mx.nd.greater(x, y)
NDArray.__ge__       x.__ge__(y) <=> x>=y <=> mx.nd.greater_equal(x, y)
NDArray.__eq__       x.__eq__(y) <=> x==y <=> mx.nd.equal(x, y)
NDArray.__ne__       x.__ne__(y) <=> x!=y <=> mx.nd.not_equal(x, y)


Indexing             Description
NDArray.__getitem__  x.__getitem__(i) <=> x[i]
NDArray.__setitem__  x.__setitem__(i, y) <=> x[i]=y

Lazy evaluation

Lazy evaluation       Description
NDArray.wait_to_read  Waits until all previous write operations on the current array are finished.

Array creation routines

Creation routine  Description
array             Creates an array from any object exposing the array interface.
empty             Returns a new array of given shape and type, without initializing entries.
zeros             Returns a new array filled with all zeros, with the given shape and type.
ones              Returns a new array filled with all ones, with the given shape and type.
full              Returns a new array of given shape and type, filled with the given value val.
arange            Returns evenly spaced values within a given interval.
load              Loads an array from file.
save              Saves a list of arrays or a dict of str->array to file.

Array manipulation routines

Changing array shape and type

Shape and type change  Description
cast                   Casts all elements of the input to a new type.
reshape                Reshapes the input array.
flatten                Flattens the input array into a 2-D array by collapsing the higher dimensions.
expand_dims            Inserts a new axis of size 1 into the array shape.

Expanding array elements

Expanding elements  Description
broadcast_to        Broadcasts the input array to a new shape.
broadcast_axes      Broadcasts the input array over particular axes.
repeat              Repeats elements of an array.
tile                Repeats the whole array multiple times.
pad                 Pads an array.

Rearranging elements

Rearranging elements  Description
transpose             Permutes the dimensions of an array.
swapaxes              Interchanges two axes of an array.
flip                  Reverses the order of elements along given axis while preserving array shape.

Joining and splitting arrays

Joining and splitting  Description
concat                 Joins input arrays along a given axis.
split                  Splits an array along a particular axis into multiple sub-arrays.

Indexing routines

Indexing routine  Description
slice             Slices a contiguous region of the array.
slice_axis        Slices along a given axis.
take              Takes elements from an input array along the given axis.
batch_take        Takes elements from a data batch.
one_hot           Returns a one-hot array.
pick              Picks elements from an input array according to the input indices along the given axis.

Mathematical functions

Arithmetic operations

Arithmetic operation  Description
add                   Returns element-wise sum of the input arrays with broadcasting.
subtract              Returns element-wise difference of the input arrays with broadcasting.
negative              Numerical negative, element-wise.
multiply              Returns element-wise product of the input arrays with broadcasting.
divide                Returns element-wise division of the input arrays with broadcasting.
dot                   Dot product of two arrays.
batch_dot             Batchwise dot product.
add_n                 Adds all input arguments element-wise.

Trigonometric functions

Trigonometric function  Description
sin                     Computes the element-wise sine of the input array.
cos                     Computes the element-wise cosine of the input array.
tan                     Computes the element-wise tangent of the input array.
arcsin                  Returns element-wise inverse sine of the input array.
arccos                  Returns element-wise inverse cosine of the input array.
arctan                  Returns element-wise inverse tangent of the input array.
degrees                 Converts each element of the input array from radians to degrees.
radians                 Converts each element of the input array from degrees to radians.

Hyperbolic functions

Hyperbolic function  Description
sinh                 Returns the hyperbolic sine of the input array, computed element-wise.
cosh                 Returns the hyperbolic cosine of the input array, computed element-wise.
tanh                 Returns the hyperbolic tangent of the input array, computed element-wise.
arcsinh              Returns the inverse hyperbolic sine of the input array, computed element-wise.
arccosh              Returns the inverse hyperbolic cosine of the input array, computed element-wise.
arctanh              Returns the inverse hyperbolic tangent of the input array, computed element-wise.

Reduce functions

Reduce function  Description
sum              Computes the sum of array elements over given axes.
nansum           Computes the sum of array elements over given axes, treating Not-a-Number (NaN) values as zero.
prod             Computes the product of array elements over given axes.
nanprod          Computes the product of array elements over given axes, treating Not-a-Number (NaN) values as one.
mean             Computes the mean of array elements over given axes.
max              Computes the max of array elements over given axes.
min              Computes the min of array elements over given axes.
norm             Flattens the input array and then computes the L2 norm.
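The NaN handling of nansum and nanprod listed above can be illustrated with a short pure-Python sketch over a flat list. The functions below are hypothetical stand-ins that only mirror the documented behaviour (NaN counts as zero in a sum and as one in a product); they are not the MXNet operators.

```python
import math


def nansum(values):
    """Sum of values, treating NaN entries as zero."""
    return sum(0.0 if math.isnan(v) else v for v in values)


def nanprod(values):
    """Product of values, treating NaN entries as one."""
    product = 1.0
    for v in values:
        if not math.isnan(v):
            product *= v
    return product


data = [1.0, float("nan"), 2.0, 3.0]
print(sum(data))      # nan: a plain sum is poisoned by the NaN entry
print(nansum(data))   # 6.0
print(nanprod(data))  # 6.0
```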


Rounding  Description
round     Returns element-wise rounded value to the nearest integer of the input.
rint      Returns element-wise rounded value to the nearest integer of the input.
fix       Returns element-wise rounded value to the nearest integer towards zero of the input.
floor     Returns element-wise floor of the input.
ceil      Returns element-wise ceiling of the input.

Exponents and logarithms

Exponents and logarithms  Description
exp                       Returns element-wise exponential value of the input.
expm1                     Returns exp(x) - 1 computed element-wise on the input.
log                       Returns element-wise natural logarithmic value of the input.
log10                     Returns element-wise base-10 logarithmic value of the input.
log2                      Returns element-wise base-2 logarithmic value of the input.
log1p                     Returns element-wise log(1 + x) value of the input.


Powers  Description
power   Returns result of first array elements raised to powers from second array, element-wise with broadcasting.
sqrt    Returns element-wise square-root value of the input.
rsqrt   Returns element-wise inverse square-root value of the input.
square  Returns element-wise squared value of the input.

Logic functions

Logic function  Description
equal           Returns the result of element-wise equal to (==) comparison operation with broadcasting.
not_equal       Returns the result of element-wise not equal to (!=) comparison operation with broadcasting.
greater         Returns the result of element-wise greater than (>) comparison operation with broadcasting.
greater_equal   Returns the result of element-wise greater than or equal to (>=) comparison operation with broadcasting.
lesser          Returns the result of element-wise lesser than (<) comparison operation with broadcasting.
lesser_equal    Returns the result of element-wise lesser than or equal to (<=) comparison operation with broadcasting.

Random sampling

Random sampling    Description
uniform            Draws samples from a uniform distribution.
normal             Draws random samples from a normal (Gaussian) distribution.
mxnet.random.seed  Seeds the random number generators in MXNet.

Sorting and searching

Sorting and searching  Description
sort                   Returns a sorted copy of an input array along the given axis.
topk                   Returns the top k elements in an input array along the given axis.
argsort                Returns the indices that would sort an input array along the given axis.
argmax                 Returns indices of the maximum values along an axis.
argmin                 Returns indices of the minimum values along an axis.


Miscellaneous  Description
maximum        Returns element-wise maximum of the input arrays with broadcasting.
minimum        Returns element-wise minimum of the input arrays with broadcasting.
clip           Clips (limits) the values in an array.
abs            Returns element-wise absolute value of the input.
sign           Returns element-wise sign of the input.
gamma          Returns the gamma function (extension of the factorial function to the reals), computed element-wise on the input array.
gammaln        Returns element-wise log of the absolute value of the gamma function of the input.

Neural network


Basic           Description
FullyConnected  Applies a linear transformation: Y = XW^T + b.
Convolution     Computes N-D convolution on (N+2)-D input.
Activation      Applies an activation function element-wise to the input.
BatchNorm       Batch normalization.
Pooling         Performs pooling on the input.
SoftmaxOutput   Computes softmax with logit loss.
softmax         Applies the softmax function.
log_softmax     Computes the log softmax of the input.


More                       Description
Correlation                Applies correlation to inputs.
Deconvolution              Applies deconvolution to input and adds a bias.
RNN                        Applies a recurrent layer to input.
Embedding                  Maps integer indices to vector representations (embeddings).
LeakyReLU                  Applies leaky ReLU activation element-wise to the input.
InstanceNorm               An operator that takes an n-dimensional input tensor (n > 2) and normalizes it using the mean and variance calculated over the spatial dimensions.
L2Normalization            Normalizes the input array using the L2 norm.
LRN                        Applies local response normalization to the input.
ROIPooling                 Performs region of interest (ROI) pooling on the input array.
SoftmaxActivation          Applies softmax activation to input.
Dropout                    Applies dropout to input.
BilinearSampler            Applies bilinear sampling to the input feature map, which is the key operation of “[NIPS2015] Spatial Transformer Networks”:
                               output[batch, channel, y_dst, x_dst] = G(data[batch, channel, y_src, x_src])
                           x_dst, y_dst enumerate all spatial locations in output;
                           x_src = grid[batch, 0, y_dst, x_dst];
                           y_src = grid[batch, 1, y_dst, x_dst];
                           G() denotes the bilinear interpolation kernel.
                           Out-of-boundary points are padded with zeros.
GridGenerator              Generates sampling grid for bilinear sampling.
UpSampling                 Performs nearest neighbor/bilinear up sampling on inputs.
SpatialTransformer         Applies a spatial transformer to input feature map.
LinearRegressionOutput     Computes and optimizes for squared loss.
LogisticRegressionOutput   Applies a logistic function to the input.
MAERegressionOutput        Computes mean absolute error of the input.
SVMOutput                  Computes support vector machine based transformation of the input.
softmax_cross_entropy      Calculates cross_entropy(data, one_hot(label)).
smooth_l1                  Calculates Smooth L1 Loss(lhs, scalar).
IdentityAttachKLSparseReg  Applies a sparse regularization to the output of a sigmoid activation function.
MakeLoss                   Makes your own loss function in network construction.
BlockGrad                  Stops gradient computation.
Custom                     Applies a custom operator implemented in a frontend language (like Python).

Symbol API


This document lists the routines of the symbolic expression package:

Package       Description
mxnet.symbol  Symbolic configuration API of MXNet.

A symbol declares computation. It is composed of operators, such as simple matrix operations (e.g. “+”) or neural network layers (e.g. a convolution layer). We can bind data to a symbol to execute the computation.

>>> a = mx.sym.Variable('a')
>>> b = mx.sym.Variable('b')
>>> c = 2 * a + b
>>> type(c)
<class 'mxnet.symbol.Symbol'>
>>> e = c.bind(mx.cpu(), {'a': mx.nd.array([1,2]), 'b':mx.nd.array([2,3])})
>>> y = e.forward()
>>> y
[<NDArray 2 @cpu(0)>]
>>> y[0].asnumpy()
array([ 4.,  7.], dtype=float32)

A detailed tutorial is available at


Most operators provided in symbol are similar to those in ndarray. Note, however, that symbol differs from ndarray in several aspects:

  • symbol adopts declarative programming: we first compose the computation, and then feed it with data to execute.
  • Most binary operators, such as + and >, do not broadcast. We need to call the broadcasting version, such as broadcast_plus, explicitly.
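The declarative style can be illustrated without MXNet. The following pure-Python sketch uses a hypothetical `Sym` class (a stand-in, not mxnet.symbol.Symbol) in which operators only record a graph, and the computation runs when values are bound to the variables. Scalars are used instead of arrays to keep the sketch short.

```python
class Sym:
    """Hypothetical stand-in for a symbolic expression node."""

    def __init__(self, op, inputs=(), name=None):
        self.op = op          # 'var', 'const', 'add', or 'mul'
        self.inputs = inputs  # child nodes, or the wrapped value for 'const'
        self.name = name

    @staticmethod
    def _wrap(value):
        # Turn plain Python numbers into constant nodes.
        return value if isinstance(value, Sym) else Sym('const', (value,))

    def __add__(self, other):
        return Sym('add', (self, Sym._wrap(other)))

    __radd__ = __add__  # addition is commutative, so operand order does not matter

    def __mul__(self, other):
        return Sym('mul', (self, Sym._wrap(other)))

    __rmul__ = __mul__  # likewise for multiplication

    def eval(self, bindings):
        """Execute the declared computation once data is bound to the variables."""
        if self.op == 'var':
            return bindings[self.name]
        if self.op == 'const':
            return self.inputs[0]
        lhs, rhs = (node.eval(bindings) for node in self.inputs)
        return lhs + rhs if self.op == 'add' else lhs * rhs


def var(name):
    return Sym('var', name=name)


a, b = var('a'), var('b')
c = 2 * a + b                      # only builds a graph; nothing is computed yet
result = c.eval({'a': 3, 'b': 4})  # the computation happens here
print(result)                      # 10
```

The same graph `c` can be evaluated again with different bindings, which is the property that lets a declared computation be optimized once and executed many times.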

In the rest of this document, we first give an overview of the methods provided by the symbol.Symbol class, and then list the other routines provided by the symbol package.

The Symbol class


Compose multiple symbols into a new one by an operator.

Composition      Description
Symbol.__call__  Composes symbol on inputs.

Arithmetic operations

Arithmetic operation  Description
Symbol.__add__        x.__add__(y) <=> x+y
Symbol.__sub__        x.__sub__(y) <=> x-y
Symbol.__rsub__       x.__rsub__(y) <=> y-x
Symbol.__neg__        x.__neg__() <=> -x
Symbol.__mul__        x.__mul__(y) <=> x*y
Symbol.__div__        x.__div__(y) <=> x/y
Symbol.__rdiv__       x.__rdiv__(y) <=> y/x
Symbol.__pow__        x.__pow__(y) <=> x**y

Comparison operators

Comparison operator  Description
Symbol.__lt__        x.__lt__(y) <=> x<y
Symbol.__le__        x.__le__(y) <=> x<=y
Symbol.__gt__        x.__gt__(y) <=> x>y
Symbol.__ge__        x.__ge__(y) <=> x>=y
Symbol.__eq__        x.__eq__(y) <=> x==y
Symbol.__ne__        x.__ne__(y) <=> x!=y

Query information

Query information            Description
Symbol.name                  Gets the name string from the symbol; this only works for a non-grouped symbol.
Symbol.list_arguments        Lists all the arguments in the symbol.
Symbol.list_outputs          Lists all the outputs in the symbol.
Symbol.list_auxiliary_states Lists all the auxiliary states in the symbol.
Symbol.list_attr             Gets all attributes from the symbol.
Symbol.attr                  Gets attribute string from the symbol.
Symbol.attr_dict             Recursively gets all attributes from the symbol and its children.

Get internal and output symbol

Internal and output symbols  Description
Symbol.__getitem__           x.__getitem__(i) <=> x[i]
Symbol.__iter__              Returns all outputs in a list.
Symbol.get_internals         Gets a new grouped symbol sgroup.
Symbol.get_children          Gets a new grouped symbol whose output contains inputs to output nodes of the original symbol.

Inference type and shape

Type and shape inference    Description
Symbol.infer_type           Infers the type of all arguments and all outputs, given the known types for some arguments.
Symbol.infer_shape          Infers the shapes of all arguments and all outputs, given the known shapes of some arguments.
Symbol.infer_shape_partial  Infers the shape partially.


Bind                Description
Symbol.bind         Binds the current symbol to get an executor.
Symbol.simple_bind  Binds the current symbol to get an executor, allocating all the ndarrays needed.


Save              Description
Symbol.save       Saves symbol to a file.
Symbol.tojson     Saves symbol to a JSON string.
Symbol.debug_str  Gets a debug string.

Symbol creation routines

Symbol creation routine  Description
var                      Creates a symbolic variable with the specified name.
zeros                    Returns a new symbol of given shape and type, filled with zeros.
ones                     Returns a new symbol of given shape and type, filled with ones.
arange                   Returns evenly spaced values within a given interval.

Symbol manipulation routines

Changing shape and type

Shape and type change  Description
cast                   Casts all elements of the input to the new type.
reshape                Reshapes the input array into a new shape.
flatten                Flattens the input array into a 2-D array by collapsing the higher dimensions.
expand_dims            Inserts a new axis of size 1 into the array shape.

Expanding elements

Expanding elements  Description
broadcast_to        Broadcasts the input array to a new shape.
broadcast_axes      Broadcasts the input array over particular axes.
repeat              Repeats elements of an array.
tile                Repeats the whole array multiple times.
pad                 Pads an array.

Rearranging elements

Rearranging elements  Description
transpose             Permutes the dimensions of an array.
swapaxes              Interchanges two axes of an array.
flip                  Reverses the elements of an array along the given axis.

Joining and splitting symbols

Joining and splitting symbols  Description
concat                         Joins input arrays along the given axis.
split                          Splits an array along a particular axis into multiple sub-arrays.

Indexing routines

Indexing routine  Description
slice             Slices a contiguous region of the array.
slice_axis        Slices along a given axis.
take              Takes elements from an input array along the given axis.
batch_take        Takes elements from a data batch.
one_hot           Returns a one-hot array.

Mathematical functions

Arithmetic operations

Arithmetic operation  Description
broadcast_add         Returns element-wise sum of the input arrays with broadcasting.
broadcast_sub         Returns element-wise difference of the input arrays with broadcasting.
broadcast_mul         Returns element-wise product of the input arrays with broadcasting.
broadcast_div         Returns element-wise division of the input arrays with broadcasting.
negative              Numerical negative, element-wise.
dot                   Dot product of two arrays.
batch_dot             Batchwise dot product.
add_n                 Adds all input arguments element-wise.

Trigonometric functions

Trigonometric function  Description
sin                     Computes the element-wise sine of the input array.
cos                     Computes the element-wise cosine of the input array.
tan                     Computes the element-wise tangent of the input array.
arcsin                  Returns element-wise inverse sine of the input array.
arccos                  Returns element-wise inverse cosine of the input array.
arctan                  Returns element-wise inverse tangent of the input array.
hypot                   Returns the hypotenuse of a right-angled triangle, given its “legs” (left and right inputs).
broadcast_hypot         Returns the hypotenuse of a right-angled triangle, given its “legs”, with broadcasting.
degrees                 Converts each element of the input array from radians to degrees.
radians                 Converts each element of the input array from degrees to radians.

Hyperbolic functions

Hyperbolic function  Description
sinh                 Returns the hyperbolic sine of the input array, computed element-wise.
cosh                 Returns the hyperbolic cosine of the input array, computed element-wise.
tanh                 Returns the hyperbolic tangent of the input array, computed element-wise.
arcsinh              Returns the inverse hyperbolic sine of the input array, computed element-wise.
arccosh              Returns the inverse hyperbolic cosine of the input array, computed element-wise.
arctanh              Returns the inverse hyperbolic tangent of the input array, computed element-wise.

Reduce functions

Reduce function  Description
sum              Computes the sum of array elements over given axes.
nansum           Computes the sum of array elements over given axes, with NaN ignored.
prod             Computes the product of array elements over given axes.
nanprod          Computes the product of array elements over given axes, with NaN ignored.
mean             Computes the mean of array elements over given axes.
max              Computes the max of array elements over given axes.
min              Computes the min of array elements over given axes.
norm             Computes the L2 norm of the input array.


Rounding  Description
round     Returns element-wise rounded value to the nearest integer of the input.
rint      Returns element-wise rounded value to the nearest integer of the input.
fix       Returns element-wise rounded value to the nearest integer towards zero of the input.
floor     Returns element-wise floor of the input.
ceil      Returns element-wise ceiling of the input.

Exponents and logarithms

Exponents and logarithms  Description
exp                       Returns element-wise exponential value of the input.
expm1                     Returns exp(x) - 1 computed element-wise on the input.
log                       Returns element-wise natural logarithmic value of the input.
log10                     Returns element-wise base-10 logarithmic value of the input.
log2                      Returns element-wise base-2 logarithmic value of the input.
log1p                     Returns element-wise log(1 + x) value of the input.


Powers           Description
broadcast_power  Returns result of first array elements raised to powers from second array, element-wise with broadcasting.
sqrt             Returns element-wise square-root value of the input.
rsqrt            Returns element-wise inverse square-root value of the input.
square           Returns element-wise squared value of the input.

Logic functions

Logic function           Description
broadcast_equal          Returns the result of element-wise equal to (==) comparison operation with broadcasting.
broadcast_not_equal      Returns the result of element-wise not equal to (!=) comparison operation with broadcasting.
broadcast_greater        Returns the result of element-wise greater than (>) comparison operation with broadcasting.
broadcast_greater_equal  Returns the result of element-wise greater than or equal to (>=) comparison operation with broadcasting.
broadcast_lesser         Returns the result of element-wise lesser than (<) comparison operation with broadcasting.
broadcast_lesser_equal   Returns the result of element-wise lesser than or equal to (<=) comparison operation with broadcasting.

Random sampling

Random sampling    Description
uniform            Draws samples from a uniform distribution.
normal             Draws random samples from a normal (Gaussian) distribution.
mxnet.random.seed  Seeds the random number generators in MXNet.

Sorting and searching

Sorting and searching  Description
sort                   Returns a sorted copy of an input array along the given axis.
topk                   Returns the top k elements in an input array along the given axis.
argsort                Returns the indices that would sort an input array along the given axis.
argmax                 Returns indices of the maximum values along an axis.
argmin                 Returns indices of the minimum values along an axis.


Miscellaneous      Description
maximum            Returns the element-wise maximum of the left and right inputs.
minimum            Returns the element-wise minimum of the left and right inputs.
broadcast_maximum  Returns element-wise maximum of the input arrays with broadcasting.
broadcast_minimum  Returns element-wise minimum of the input arrays with broadcasting.
clip               Clips (limits) the values in an array.
abs                Returns element-wise absolute value of the input.
sign               Returns element-wise sign of the input.
gamma              Returns the gamma function (extension of the factorial function to the reals), computed element-wise on the input array.
gammaln            Returns element-wise log of the absolute value of the gamma function of the input.

Neural network


Basic           Description
FullyConnected  Applies a linear transformation: Y = XW^T + b.
Convolution     Computes N-D convolution on (N+2)-D input.
Activation      Applies an element-wise activation function.
BatchNorm       Batch normalization.
Pooling         Performs pooling on the input.
SoftmaxOutput   Computes softmax with logit loss.
softmax         Applies the softmax function.
log_softmax     Computes the log softmax of the input.


More                       Description
Correlation                Applies correlation to inputs.
Deconvolution              Applies deconvolution to input and then adds a bias.
RNN                        Applies a recurrent layer to input.
Embedding                  Maps integer indices to vector representations (embeddings).
LeakyReLU                  Applies leaky ReLU activation.
InstanceNorm               An operator that takes an n-dimensional input tensor (n > 2) and normalizes it using the mean and variance calculated over the spatial dimensions.
L2Normalization            Sets the L2 norm of each instance to a constant.
LRN                        Applies local response normalization to the input.
ROIPooling                 Performs region of interest (ROI) pooling on the input array.
SoftmaxActivation          Applies softmax activation to input.
Dropout                    Applies dropout to input.
BilinearSampler            Applies bilinear sampling to the input feature map, which is the key operation of “[NIPS2015] Spatial Transformer Networks”:
                               output[batch, channel, y_dst, x_dst] = G(data[batch, channel, y_src, x_src])
                           x_dst, y_dst enumerate all spatial locations in output;
                           x_src = grid[batch, 0, y_dst, x_dst];
                           y_src = grid[batch, 1, y_dst, x_dst];
                           G() denotes the bilinear interpolation kernel.
                           Out-of-boundary points are padded with zeros.
GridGenerator              Generates sampling grid for bilinear sampling.
UpSampling                 Performs nearest neighbor/bilinear up sampling on inputs. This function supports a variable number of positional inputs.
SpatialTransformer         Applies a spatial transformer to input feature map.
LinearRegressionOutput     Computes and optimizes for squared loss.
LogisticRegressionOutput   Applies a logistic function to the input.
MAERegressionOutput        Computes mean absolute error.
SVMOutput                  Computes support vector machine based transformation of the input.
softmax_cross_entropy      Calculates cross_entropy(data, one_hot(label)).
smooth_l1                  Calculates Smooth L1 Loss(lhs, scalar).
IdentityAttachKLSparseReg  Applies a sparse regularization to the output of a sigmoid activation function.
MakeLoss                   Gets output from a symbol and passes 1 gradient back.
BlockGrad                  Gets output from a symbol and passes 0 gradient back.
Custom                     Custom operator implemented in a frontend language.

Module API


The module API, defined in the module (or simply mod) package, provides an intermediate and high-level interface for performing computation with a Symbol. One can roughly think of a module as a machine that can execute a program defined by a Symbol.

The class module.Module is a commonly used module that accepts a Symbol as input:

data = mx.symbol.Variable('data')
fc1  = mx.symbol.FullyConnected(data, name='fc1', num_hidden=128)
act1 = mx.symbol.Activation(fc1, name='relu1', act_type="relu")
fc2  = mx.symbol.FullyConnected(act1, name='fc2', num_hidden=10)
out  = mx.symbol.SoftmaxOutput(fc2, name = 'softmax')
mod = mx.mod.Module(out)  # create a module from the given Symbol

Assume there is a valid MXNet data iterator named data. We can then initialize the module:

mod.bind(data_shapes=data.provide_data,
         label_shapes=data.provide_label)  # allocate memory for the given input shapes
mod.init_params()  # initialize parameters with the default random initializer

Now the module is able to compute. We can call the high-level API to train and predict:

mod.fit(data, num_epoch=10, ...)  # train
mod.predict(new_data)  # predict on new data

or use the intermediate APIs to perform step-by-step computations:

mod.forward(data_batch)  # forward on the provided data batch
mod.backward()  # backward to calculate the gradients
mod.update()  # update parameters using the default optimizer
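The forward, backward, update cycle can be sketched in plain Python for a one-parameter model. The `TinyModule` class below is a hypothetical illustration (a model y = w * x with mean squared error and plain gradient descent), not the MXNet Module implementation; the learning rate and data are assumptions chosen for the example.

```python
class TinyModule:
    """Hypothetical one-parameter model y = w * x with squared loss and plain SGD."""

    def __init__(self, learning_rate=0.1):
        self.w = 0.0
        self.learning_rate = learning_rate
        self._x = self._y = self._pred = self._grad = None

    def forward(self, x, y):
        # Forward pass: compute and cache the predictions for this batch.
        self._x, self._y = x, y
        self._pred = [self.w * xi for xi in x]
        return self._pred

    def backward(self):
        # Backward pass: gradient of the mean squared error with respect to w.
        n = len(self._x)
        self._grad = sum(
            2 * (p - t) * xi for p, t, xi in zip(self._pred, self._y, self._x)
        ) / n

    def update(self):
        # Parameter update using the gradient from the last backward pass.
        self.w -= self.learning_rate * self._grad


mod_sketch = TinyModule()
xs, ys = [1.0, 2.0, 3.0], [2.0, 4.0, 6.0]  # data generated by y = 2x
for _ in range(50):
    mod_sketch.forward(xs, ys)
    mod_sketch.backward()
    mod_sketch.update()
print(round(mod_sketch.w, 3))  # 2.0
```

Keeping the three steps separate is what allows custom logic, such as gradient inspection, to be placed between them.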

A detailed tutorial is available at


module replaces model, which has been deprecated.

The module package provides several modules:

Overview          Description
BaseModule        The base class of a module.
Module            A basic module that wraps a Symbol.
SequentialModule  A container module that can chain multiple modules together.
BucketingModule   A module that helps to deal efficiently with varying-length inputs.
PythonModule      A convenient module class that implements many of the module APIs as empty functions.
PythonLossModule  A convenient module class that implements many of the module APIs as empty functions.

We summarize the interface for each class in the following sections.

The BaseModule class

The BaseModule is the base class for all other module classes. It defines the interface each module class should provide.

Initialize memory

Initialize memory  Description
BaseModule.bind    Binds the symbols to construct executors.

Get and set parameters

Get and set parameters  Description
BaseModule.init_params  Initializes the parameters and auxiliary states.
BaseModule.set_params   Assigns parameter and aux state values.
BaseModule.get_params   Gets parameters; these are potentially copies of the actual parameters used to do computation on the device.
BaseModule.save_params  Saves model parameters to file.
BaseModule.load_params  Loads model parameters from file.

Train and predict

Train and predict        Description
BaseModule.fit           Trains the module parameters.
BaseModule.score         Runs prediction on eval_data and evaluates the performance according to eval_metric.
BaseModule.iter_predict  Iterates over predictions.
BaseModule.predict       Runs prediction and collects the outputs.

Forward and backward

Forward and backward         Description
BaseModule.forward           Forward computation.
BaseModule.backward          Backward computation.
BaseModule.forward_backward  A convenient function that calls both forward and backward.

Update parameters

Update parameters          Description
BaseModule.init_optimizer  Installs and initializes optimizers.
BaseModule.update          Updates parameters according to the installed optimizer and the gradients computed in the previous forward-backward batch.
BaseModule.update_metric   Evaluates and accumulates evaluation metric on outputs of the last forward computation.

Input and output

Input and output            Description
BaseModule.data_names       A list of names for data required by this module.
BaseModule.output_names     A list of names for the outputs of this module.
BaseModule.data_shapes      A list of (name, shape) pairs specifying the data inputs to this module.
BaseModule.label_shapes     A list of (name, shape) pairs specifying the label inputs to this module.
BaseModule.output_shapes    A list of (name, shape) pairs specifying the outputs of this module.
BaseModule.get_outputs      Gets outputs of the previous forward computation.
BaseModule.get_input_grads  Gets the gradients to the inputs, computed in the previous backward computation.


Others                      Description
BaseModule.get_states       Gets states from all devices.
BaseModule.set_states       Sets value for states.
BaseModule.install_monitor  Installs monitor on all executors.
BaseModule.symbol           Gets the symbol associated with this module.

Other built-in modules

Besides the basic interface defined in BaseModule, each module class supports additional functionality. We summarize them in this section.

Class Module

Module                        Description
Module.load                   Creates a model from a previously saved checkpoint.
Module.save_checkpoint        Saves current progress to checkpoint.
Module.reshape                Reshapes the module for new input shapes.
Module.borrow_optimizer       Borrows optimizer from a shared module.
Module.save_optimizer_states  Saves optimizer (updater) state to file.
Module.load_optimizer_states  Loads optimizer (updater) state from file.

Class BucketingModule

BucketingModule                Description
BucketingModule.switch_bucket

Class SequentialModule

SequentialModule      Description
SequentialModule.add  Adds a module to the chain.


KVStore API

Basic Push and Pull

Provides basic operations over multiple devices (GPUs) on a single machine.


Let’s consider a simple example. It initializes an (int, NDArray) pair in the store, and then pulls the value out.

>>> kv = mx.kv.create('local') # create a local kv store.
>>> shape = (2,3)
>>> kv.init(3, mx.nd.ones(shape)*2)
>>> a = mx.nd.zeros(shape)
>>> kv.pull(3, out = a)
>>> print(a.asnumpy())
[[ 2.  2.  2.]
 [ 2.  2.  2.]]

Push, Aggregation, and Updater

For any key that’s been initialized, you can push a new value with the same shape to the key, as follows:

>>> kv.push(3, mx.nd.ones(shape)*8)
>>> kv.pull(3, out = a) # pull out the value
>>> print(a.asnumpy())
[[ 8.  8.  8.]
 [ 8.  8.  8.]]

The data that you want to push can be stored on any device. Furthermore, you can push multiple values into the same key, where KVStore first sums all of these values, and then pushes the aggregated value, as follows:

>>> gpus = [mx.gpu(i) for i in range(4)]
>>> b = [mx.nd.ones(shape, gpu) for gpu in gpus]
>>> kv.push(3, b)
>>> kv.pull(3, out = a)
>>> print(a.asnumpy())
[[ 4.  4.  4.]
 [ 4.  4.  4.]]

For each push command, KVStore uses an updater to merge the pushed value into the stored value. The default updater is ASSIGN. You can replace the default to control how data is merged.

>>> def update(key, input, stored):
...     print("update on key: %d" % key)
...     stored += input * 2
>>> kv._set_updater(update)
>>> kv.pull(3, out=a)
>>> print(a.asnumpy())
[[ 4.  4.  4.]
 [ 4.  4.  4.]]
>>> kv.push(3, mx.nd.ones(shape))
update on key: 3
>>> kv.pull(3, out=a)
>>> print(a.asnumpy())
[[ 6.  6.  6.]
 [ 6.  6.  6.]]
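The aggregation and updater behavior described above can be mimicked without MXNet. The following is a minimal plain-Python sketch, not MXNet's implementation: ToyKVStore, set_updater, and double_and_add are hypothetical names, and values are flat lists of floats rather than NDArrays.

```python
# Plain-Python sketch of push aggregation and a custom updater.
# ToyKVStore is a hypothetical stand-in, not MXNet's KVStore.

class ToyKVStore:
    def __init__(self):
        self._store = {}
        # Default updater: overwrite the stored value (ASSIGN).
        self._updater = self._assign

    @staticmethod
    def _assign(key, pushed, stored):
        stored[:] = pushed

    def set_updater(self, updater):
        self._updater = updater

    def init(self, key, value):
        self._store[key] = list(value)

    def push(self, key, values):
        # Accept one value or a list of values (one per device).
        if values and not isinstance(values[0], list):
            values = [values]
        # Sum element-wise across devices before applying the updater.
        merged = [sum(column) for column in zip(*values)]
        self._updater(key, merged, self._store[key])

    def pull(self, key):
        return list(self._store[key])


def double_and_add(key, pushed, stored):
    # Same rule as the updater above: stored += pushed * 2, in place.
    for i, v in enumerate(pushed):
        stored[i] += v * 2


kv = ToyKVStore()
kv.init(3, [2.0, 2.0])
kv.push(3, [[1.0, 1.0]] * 4)   # four "devices" each push ones
after_sum = kv.pull(3)         # [4.0, 4.0]
kv.set_updater(double_and_add)
kv.push(3, [1.0, 1.0])
after_update = kv.pull(3)      # [6.0, 6.0]
```

The sketch reproduces the numbers in the session above: four pushed arrays of ones sum to 4, and the custom updater then adds twice the pushed value to reach 6.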


You’ve already seen how to pull a single key-value pair. Similar to the way that you use the push command, you can pull the value into several devices with a single call.

>>> b = [mx.nd.ones(shape, gpu) for gpu in gpus]
>>> kv.pull(3, out = b)
>>> print(b[1].asnumpy())
[[ 6.  6.  6.]
 [ 6.  6.  6.]]

List Key-Value Pairs

All of the operations that we’ve discussed so far are performed on a single key. KVStore also provides an interface for operating on a list of key-value pairs. For a single device, use the following:

>>> keys = [5, 7, 9]
>>> kv.init(keys, [mx.nd.ones(shape)]*len(keys))
>>> kv.push(keys, [mx.nd.ones(shape)]*len(keys))
update on key: 5
update on key: 7
update on key: 9
>>> b = [mx.nd.zeros(shape)]*len(keys)
>>> kv.pull(keys, out = b)
>>> print(b[1].asnumpy())
[[ 3.  3.  3.]
 [ 3.  3.  3.]]

For multiple devices:

Data Loading API


This document summarizes the supported data formats and the iterator APIs for reading the data, including

Overview         Description
                 Data iterators for common data formats.
mxnet.recordio   Read and write for the RecordIO data format.
mxnet.image      Read individual image files and perform augmentations.

It will also show how to write an iterator for a new data format.

A data iterator reads data batch by batch.

>>> data = mx.nd.ones((100,10))
>>> nd_iter = mx.io.NDArrayIter(data, batch_size=25)
>>> for batch in nd_iter:
...     print(batch.data)
[<NDArray 25x10 @cpu(0)>]
[<NDArray 25x10 @cpu(0)>]
[<NDArray 25x10 @cpu(0)>]
[<NDArray 25x10 @cpu(0)>]

After nd_iter.reset() is called, the iterator reads the data again from the beginning.
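The batch-by-batch reading and the effect of reset() can be illustrated with a small plain-Python stand-in. This is a sketch under stated assumptions, not how MXNet implements its iterators: ToyBatchIter is a hypothetical class over a Python list, and it simply drops a trailing partial batch.

```python
class ToyBatchIter:
    """Yield fixed-size slices of a list; a stand-in for a data iterator."""

    def __init__(self, data, batch_size):
        self.data = data
        self.batch_size = batch_size
        self.cursor = 0

    def reset(self):
        # Rewind so the next pass starts from the first sample.
        self.cursor = 0

    def __iter__(self):
        return self

    def __next__(self):
        # Stop when a full batch is no longer available (simplification).
        if self.cursor + self.batch_size > len(self.data):
            raise StopIteration
        batch = self.data[self.cursor:self.cursor + self.batch_size]
        self.cursor += self.batch_size
        return batch


it = ToyBatchIter(list(range(100)), batch_size=25)
first_pass = [len(b) for b in it]    # four batches of 25
second_pass = [len(b) for b in it]   # exhausted, so no batches
it.reset()
third_pass = [len(b) for b in it]    # four batches again
```

Without the reset() call, a second loop over the exhausted iterator yields nothing, which is why training loops typically reset the iterator at the start of each epoch.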

In addition, an iterator provides information about each batch, including the shapes and names.

>>> nd_iter = mx.io.NDArrayIter(data={'data':mx.nd.ones((100,10))},
...                             label={'softmax_label':mx.nd.ones((100,))},
...                             batch_size=25)
>>> print(nd_iter.provide_data)
[DataDesc[data,(25, 10L),<type 'numpy.float32'>,NCHW]]
>>> print(nd_iter.provide_label)
[DataDesc[softmax_label,(25,),<type 'numpy.float32'>,NCHW]]

So this iterator can be used to train a symbol whose input data variable has name data and input label variable has name softmax_label.

>>> data = mx.sym.Variable('data')
>>> label = mx.sym.Variable('softmax_label')
>>> fullc = mx.sym.FullyConnected(data=data, num_hidden=1)
>>> loss = mx.sym.SoftmaxOutput(data=fullc, label=label)  # feed the layer output, not the raw input
>>> mod = mx.mod.Module(loss)
>>> print(mod.data_names)
['data']
>>> print(mod.label_names)
['softmax_label']
>>> mod.bind(data_shapes=nd_iter.provide_data, label_shapes=nd_iter.provide_label)

Then we can call mod.fit(nd_iter, num_epoch=2) to train loss for 2 epochs.

Data iterators

Data iterator                Description
io.NDArrayIter               Iterating on either mx.nd.NDArray or numpy.ndarray.
io.CSVIter                   Iterating on CSV files.
io.ImageRecordIter           Iterating on image RecordIO files.
io.ImageRecordUInt8Iter      Iterating on image RecordIO files.
io.MNISTIter                 Iterating on the MNIST dataset.
recordio.MXRecordIO          Read/write RecordIO format data.
recordio.MXIndexedRecordIO   Read/write RecordIO format data supporting random access.
image.ImageIter              Image data iterator with a large number of augmentation choices.

Helper classes and functions

Data structures and other iterators provided in the packages.

Helper class or function   Description
io.DataDesc                Data description.
io.DataBatch               A data batch.
io.DataIter                The base class of a data iterator.
io.ResizeIter              Resize a data iterator to a given number of batches.
io.PrefetchingIter         Performs pre-fetch for other data iterators.
io.MXDataIter              A Python wrapper for a C++ data iterator.

A list of image modification functions provided by mxnet.image.

mxnet.image                Description
image.imdecode             Decode an image to an NDArray.
image.scale_down           Scale down crop size if it’s bigger than image size.
image.resize_short         Resize shorter edge to size.
image.fixed_crop           Crop src at fixed location, and (optionally) resize it to size.
image.random_crop          Randomly crop src with size.
image.center_crop          Centrally crop src with size.
image.color_normalize      Normalize src with mean and std.
image.random_size_crop     Randomly crop src with size.
image.ResizeAug            Make resize shorter edge to size augmenter.
image.RandomCropAug        Make random crop augmenter.
image.RandomSizedCropAug   Make random crop with random resizing and random aspect ratio jitter augmenter.
image.CenterCropAug        Make center crop augmenter.
image.RandomOrderAug       Apply list of augmenters in random order.
image.ColorJitterAug       Apply random brightness, contrast and saturation jitter in random order.
image.LightingAug          Add PCA based noise.
image.ColorNormalizeAug    Mean and std normalization.
image.HorizontalFlipAug    Random horizontal flipping.
image.CastAug              Cast to float32.
image.CreateAugmenter      Creates an augmenter list.

Functions to read and write RecordIO files.

RecordIO              Description
recordio.pack         Pack a string into MXImageRecord.
recordio.unpack       Unpack a MXImageRecord to string.
recordio.unpack_img   Unpack a MXImageRecord to image.
recordio.pack_img     Pack an image into MXImageRecord.

Develop a new iterator

Writing a new data iterator in Python is straightforward. Most MXNet training and inference programs accept an iterable object with provide_data and provide_label properties. This section shows how to write an iterator from scratch.

The following example demonstrates how to combine multiple data iterators into a single one. It can be used for multi-modality training such as image captioning, in which images can be read by ImageRecordIter and documents by CSVIter.

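import mxnet as mx
from mxnet.io import DataBatch

class MultiIter:
    def __init__(self, iter_list):
        self.iters = iter_list
    def next(self):
        # Fetch one batch from every underlying iterator.
        batches = [i.next() for i in self.iters]
        # Flatten the per-iterator data and label lists into one batch.
        return DataBatch(data=[d for b in batches for d in b.data],
                         label=[l for b in batches for l in b.label])
    def reset(self):
        for i in self.iters:
            i.reset()
    @property
    def provide_data(self):
        return [d for i in self.iters for d in i.provide_data]
    @property
    def provide_label(self):
        return [l for i in self.iters for l in i.provide_label]

# Other iterator arguments are omitted for brevity.
multi_iter = MultiIter([mx.io.ImageRecordIter('image.rec'), mx.io.CSVIter('txt.csv')])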
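To see the combining logic run without any data files, here is a plain-Python sketch of the same idea. It is only an illustration: ToyBatch and ToyMultiIter are hypothetical stand-ins for DataBatch and the class above, and they use Python 3's __next__ protocol instead of a next method.

```python
from collections import namedtuple

# Hypothetical stand-in for DataBatch: lists of data and label arrays.
ToyBatch = namedtuple("ToyBatch", ["data", "label"])


class ToyMultiIter:
    def __init__(self, iter_list):
        self.iters = iter_list

    def __iter__(self):
        return self

    def __next__(self):
        # next() raises StopIteration as soon as any source is exhausted.
        batches = [next(i) for i in self.iters]
        # Concatenate the per-source lists into one combined batch.
        return ToyBatch(
            data=[d for b in batches for d in b.data],
            label=[l for b in batches for l in b.label],
        )


images = iter([ToyBatch(["img0"], ["cat"]), ToyBatch(["img1"], ["dog"])])
texts = iter([ToyBatch(["txt0"], ["cap0"]), ToyBatch(["txt1"], ["cap1"])])
merged = list(ToyMultiIter([images, texts]))
```

Each combined batch carries the data and labels of every source side by side, so a model with two inputs can consume them in the order the sources were listed.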

Parsing and other preprocessing, such as augmentation, may be expensive. If performance is critical, we can implement a data iterator in C++. Refer to src/io for examples.

Optimization: initialize and update weights


This document summarizes the APIs used to initialize and update the model weights during training

Overview             Description
mxnet.initializer    Weight initializer.
mxnet.optimizer      Weight updating functions.
mxnet.lr_scheduler   Scheduling learning rate.

and how to develop a new optimization algorithm in MXNet.

Assume there is a predefined Symbol and that a Module has been created for it:

>>> data = mx.symbol.Variable('data')
>>> label = mx.symbol.Variable('softmax_label')
>>> fc = mx.symbol.FullyConnected(data, name='fc', num_hidden=10)
>>> loss = mx.symbol.SoftmaxOutput(fc, label, name='softmax')
>>> mod = mx.mod.Module(loss)
>>> mod.bind(data_shapes=[('data', (128,20))], label_shapes=[('softmax_label', (128,))])

Next we can initialize the weights with values sampled uniformly from [-1,1]:

>>> mod.init_params(mx.initializer.Uniform(scale=1.0))

Then we train the model with standard SGD, multiplying the learning rate by 0.9 every 100 batches.

>>> lr_sch = mx.lr_scheduler.FactorScheduler(step=100, factor=0.9)
>>> mod.init_optimizer(
...     optimizer='sgd', optimizer_params=(('learning_rate', 0.1), ('lr_scheduler', lr_sch)))
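As a rough check of what this schedule does to the learning rate, the decay can be written in closed form. The function below is a simplified plain-Python sketch, not FactorScheduler itself: factor_lr is a hypothetical name, and the exact update at which MXNet applies each decay may differ by one from this formula.

```python
def factor_lr(base_lr, num_update, step=100, factor=0.9):
    """Learning rate after num_update batches under step-wise decay."""
    # One multiplication by `factor` for every completed block of `step` updates.
    return base_lr * factor ** (num_update // step)


# Learning rate at a few points in training, starting from 0.1.
lrs = [factor_lr(0.1, n) for n in (0, 99, 100, 250)]
```

With the settings above, the rate stays at 0.1 for the first 100 batches, drops to about 0.09, and reaches about 0.081 after two decays.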

Finally, run mod.fit to start training.

The mxnet.initializer package

The base class Initializer defines the default behavior for initializing parameters other than the weight, such as setting the bias to 0. The other classes then define how to initialize the weight.

mxnet.initializer   Description
Initializer         The base class of an initializer.
Uniform             Initialize the weight with value uniformly sampled from [-scale, scale].
Normal              Initialize the weight with value sampled according to normal(0, sigma).
Load                Initialize by loading data from file or dict.
Mixed               Initialize parameters using multiple initializers.
Zero                Initialize the weight to 0.
One                 Initialize the weight to 1.
Constant            Initialize the weight to a scalar value.
Orthogonal          Initialize weight as orthogonal matrix.
Xavier              Initialize the weight with Xavier or other similar schemes.
MSRAPrelu           Initialize the weight according to a MSRA paper.
Bilinear            Initialize weight for upsampling layers.
FusedRNN            Initialize parameters for fused rnn layers.

The mxnet.optimizer package

The base class Optimizer accepts commonly shared arguments such as learning_rate and defines the interface. Every other class in this package implements one weight updating function.

mxnet.optimizer   Description
Optimizer         The base class inherited by all optimizers.
SGD               The SGD optimizer with momentum and weight decay.
NAG               Nesterov accelerated SGD.
RMSProp           The RMSProp optimizer.
Adam              The Adam optimizer.
AdaGrad           The AdaGrad optimizer.
AdaDelta          The AdaDelta optimizer.
DCASGD            The DCASGD optimizer.
SGLD              Stochastic Gradient Riemannian Langevin Dynamics.

The mxnet.lr_scheduler package

The base class LRScheduler defines the interface, while other classes implement various schemes to change the learning rate during training.

mxnet.lr_scheduler     Description
LRScheduler            Base class of a learning rate scheduler.
FactorScheduler        Reduce the learning rate by a factor for every n steps.
MultiFactorScheduler   Reduce the learning rate by a factor at each step in a given list.

Implement a new algorithm

Most classes listed in this document are implemented in Python by using NDArray. So implementing new weight updating or initialization functions is straightforward.

For an initializer, create a subclass of Initializer and define the _init_weight method. We can also change the default behavior for other parameters by overriding methods such as _init_bias. See the classes in mxnet.initializer for examples.
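The division of labor between the base class and its subclasses can be sketched in plain Python. This is an illustration only, not MXNet's Initializer: ToyInitializer and ToyUniform are hypothetical classes, parameters are Python lists, and the dispatch on the parameter-name suffix and the zero default for the bias are assumptions of this sketch.

```python
import random


class ToyInitializer:
    """Base class: handles the bias, leaves the weight to subclasses."""

    def __call__(self, name, arr):
        # Assumed convention: the name suffix selects the init method.
        if name.endswith("weight"):
            self._init_weight(name, arr)
        elif name.endswith("bias"):
            self._init_bias(name, arr)
        else:
            raise ValueError("unknown parameter name: %s" % name)

    def _init_bias(self, name, arr):
        arr[:] = [0.0] * len(arr)  # assumed default for this sketch

    def _init_weight(self, name, arr):
        raise NotImplementedError("subclasses define weight initialization")


class ToyUniform(ToyInitializer):
    def __init__(self, scale=0.07, seed=0):
        self.scale = scale
        self.rng = random.Random(seed)

    def _init_weight(self, name, arr):
        # Fill in place with samples from [-scale, scale].
        arr[:] = [self.rng.uniform(-self.scale, self.scale) for _ in arr]


init = ToyUniform(scale=1.0)
weight, bias = [None] * 4, [None] * 2
init("fc_weight", weight)
init("fc_bias", bias)
```

Only _init_weight had to be written for the new scheme; the bias handling is inherited, which is the pattern the paragraph above describes.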

For an optimizer, create a subclass of Optimizer and implement two methods, create_state and update. Also add @mx.optimizer.Optimizer.register before the class. See the classes in mxnet.optimizer for examples.
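The roles of create_state and update can be shown with a plain-Python sketch of SGD with momentum. It is an illustration under stated assumptions, not an MXNet optimizer: ToySGD is a hypothetical class that does not subclass Optimizer or use the register decorator, and weights and gradients are Python lists.

```python
class ToySGD:
    def __init__(self, learning_rate=0.1, momentum=0.9):
        self.lr = learning_rate
        self.momentum = momentum

    def create_state(self, index, weight):
        # One momentum buffer per weight, initialized to zero.
        return [0.0] * len(weight)

    def update(self, index, weight, grad, state):
        # Update the momentum buffer, then the weight, both in place.
        for i, g in enumerate(grad):
            state[i] = self.momentum * state[i] - self.lr * g
            weight[i] += state[i]


opt = ToySGD()
w = [1.0]
state = opt.create_state(0, w)
opt.update(0, w, [0.5], state)   # first step
opt.update(0, w, [0.5], state)   # second step reuses the momentum buffer
```

create_state is called once per parameter to allocate whatever the algorithm needs to remember, and update is then called on every step with that state, so the momentum built up in the first step carries into the second.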

For a learning rate scheduler, create a subclass of LRScheduler and implement the __call__ method. See the classes in mxnet.lr_scheduler for examples.
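A scheduler is essentially a callable that maps the number of updates to a learning rate. The sketch below shows that shape in plain Python, assuming a __call__(num_update) signature: ToyMultiStepScheduler is a hypothetical class that does not subclass LRScheduler, and its milestone rule is a simplification.

```python
class ToyMultiStepScheduler:
    def __init__(self, steps, factor, base_lr):
        self.steps = sorted(steps)
        self.factor = factor
        self.base_lr = base_lr

    def __call__(self, num_update):
        # Count how many milestones have been reached so far.
        passed = sum(1 for s in self.steps if num_update >= s)
        return self.base_lr * self.factor ** passed


sched = ToyMultiStepScheduler(steps=[100, 300], factor=0.5, base_lr=0.1)
schedule = [sched(n) for n in (0, 100, 299, 300)]
```

Because the rate is computed from num_update alone, the scheduler keeps no hidden state and can be queried in any order, which makes it easy to test.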


