MXNet Survey, Part 2: Python API
0. Python API: http://mxnet.io/api/python/
1. NDArray
An NDArray is a multidimensional container of items of the same type and size.
>>> x = mx.nd.array([[1, 2, 3], [4, 5, 6]])
>>> type(x)
<class 'mxnet.ndarray.NDArray'>
>>> x.shape
(2, 3)
>>> y = x + mx.nd.ones(x.shape)*3
>>> print(y.asnumpy())
[[ 4.  5.  6.]
 [ 7.  8.  9.]]
>>> z = y.as_in_context(mx.gpu(0))
>>> print(z)
<NDArray 2x3 @gpu(0)>
2. Symbol
A symbol declares a computation. It is composed of operators, such as simple matrix operations (e.g. "+") or neural network layers (e.g. a convolution layer). We can bind data to a symbol to execute the computation.
>>> a = mx.sym.Variable('a')
>>> b = mx.sym.Variable('b')
>>> c = 2 * a + b
>>> type(c)
<class 'mxnet.symbol.Symbol'>
>>> e = c.bind(mx.cpu(), {'a': mx.nd.array([1,2]), 'b': mx.nd.array([2,3])})
>>> y = e.forward()
>>> y
[<NDArray 2 @cpu(0)>]
>>> y[0].asnumpy()
array([ 4.,  7.], dtype=float32)
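Because a Symbol only declares the computation, its inputs and output shapes can be inspected before any data is bound. A short sketch continuing the example above (the shapes passed to infer_shape are illustrative):

>>> c.list_arguments()   # the free variables the computation depends on
['a', 'b']
>>> arg_shapes, out_shapes, aux_shapes = c.infer_shape(a=(2,), b=(2,))
>>> out_shapes           # shape of 2 * a + b
[(2,)]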
bind(data_shapes, label_shapes=None, for_training=True, inputs_need_grad=False, force_rebind=False, shared_module=None, grad_req='write')
Bind the symbols to construct executors. This is necessary before one can perform computation with the module. A minimal usage sketch follows the parameter list below.
Parameters:
- data_shapes (list of (str, tuple)) – Typically is data_iter.provide_data.
- label_shapes (list of (str, tuple)) – Typically is data_iter.provide_label.
- for_training (bool) – Default is True. Whether the executors should be bound for training.
- inputs_need_grad (bool) – Default is False. Whether the gradients with respect to the input data need to be computed. Typically this is not needed, but it might be when implementing composition of modules.
- force_rebind (bool) – Default is False. This function does nothing if the executors are already bound, but with this set to True the executors will be forced to rebind.
- shared_module (Module) – Default is None. This is used in bucketing. When not None, the shared module essentially corresponds to a different bucket – a module with a different symbol but the same set of parameters (e.g. unrolled RNNs with different lengths).
- grad_req (str, list of str, dict of str to str) – Requirement for gradient accumulation. Can be 'write', 'add', or 'null' (defaults to 'write'). Can be specified globally (str) or per argument (list, dict).
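A minimal sketch of how bind() might be called on a Module; the symbol, data names, and shapes below are illustrative assumptions, not taken from the original text:

import mxnet as mx

net = mx.sym.FullyConnected(mx.sym.Variable('data'), name='fc', num_hidden=10)
net = mx.sym.SoftmaxOutput(net, name='softmax')
mod = mx.mod.Module(net)  # default data_names=('data',), label_names=('softmax_label',)
mod.bind(data_shapes=[('data', (32, 100))],         # batch of 32 samples with 100 features
         label_shapes=[('softmax_label', (32,))],
         for_training=True,
         grad_req='write')                           # overwrite gradients on each backward pass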
3. Module
The module API, defined in the module (or simply mod) package, provides an intermediate- and high-level interface for performing computation with a Symbol. One can roughly think of a module as a machine which can execute a program defined by a Symbol.
The class module.Module is a commonly used module, which accepts a Symbol as the input:
data = mx.symbol.Variable('data')
fc1  = mx.symbol.FullyConnected(data, name='fc1', num_hidden=128)
act1 = mx.symbol.Activation(fc1, name='relu1', act_type="relu")
fc2  = mx.symbol.FullyConnected(act1, name='fc2', num_hidden=10)
out  = mx.symbol.SoftmaxOutput(fc2, name='softmax')
mod  = mx.mod.Module(out)  # create a module from a given Symbol
Assume there is a valid MXNet data iterator data. We can initialize the module:
mod.bind(data_shapes=data.provide_data, label_shapes=data.provide_label)  # allocate memory given the input shapes
mod.init_params()  # initialize parameters with the default random initializer

Now the module is able to compute. We can call the high-level API to train and predict:

mod.fit(data, num_epoch=10, ...)  # train
mod.predict(new_data)             # predict on new data
Or we can use intermediate-level APIs to perform step-by-step computations:
mod.forward(data_batch)  # forward on the provided data batch
mod.backward()           # backward to calculate the gradients
mod.update()             # update parameters using the default optimizer
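Putting the pieces together, here is a self-contained sketch of this workflow using synthetic data through an NDArrayIter; the array sizes, batch size, epoch count, and optimizer settings are illustrative assumptions:

import mxnet as mx
import numpy as np

# synthetic 10-class classification data
train_iter = mx.io.NDArrayIter(data=np.random.uniform(size=(100, 50)).astype('float32'),
                               label=np.random.randint(0, 10, size=(100,)).astype('float32'),
                               batch_size=25)

data = mx.symbol.Variable('data')
fc1  = mx.symbol.FullyConnected(data, name='fc1', num_hidden=128)
act1 = mx.symbol.Activation(fc1, name='relu1', act_type="relu")
fc2  = mx.symbol.FullyConnected(act1, name='fc2', num_hidden=10)
out  = mx.symbol.SoftmaxOutput(fc2, name='softmax')

mod = mx.mod.Module(out)
mod.fit(train_iter, num_epoch=2,   # bind, init_params and init_optimizer are handled internally
        optimizer='sgd', optimizer_params={'learning_rate': 0.1})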
A detailed tutorial is available at http://mxnet.io/tutorials/python/module.html.
4. KVStore
KVStore provides basic push and pull operations over multiple devices (GPUs) on a single machine.
# Initialization
>>> kv = mx.kv.create('local')  # create a local kv store
>>> shape = (2, 3)              # 2 rows, 3 cols
>>> kv.init(3, mx.nd.ones(shape)*2)
>>> a = mx.nd.zeros(shape)
>>> kv.pull(3, out=a)
>>> print(a.asnumpy())
[[ 2.  2.  2.]
 [ 2.  2.  2.]]

# Push, aggregation and updater
>>> kv.push(3, mx.nd.ones(shape)*8)
>>> kv.pull(3, out=a)  # pull out the value
>>> print(a.asnumpy())
[[ 8.  8.  8.]
 [ 8.  8.  8.]]

# You can push multiple values into the same key, where KVStore first sums
# all of these values, and then pushes the aggregated value, as follows:
>>> cpus = [mx.cpu(i) for i in range(4)]
>>> b = [mx.nd.ones(shape, cpu) for cpu in cpus]
>>> kv.push(3, b)
>>> kv.pull(3, out=a)
>>> print(a.asnumpy())
[[ 4.  4.  4.]
 [ 4.  4.  4.]]

# You can replace the default updater to control how data is merged
>>> def update(key, input, stored):
...     print("update on key: %d" % key)
...     stored += input * 2
>>> kv._set_updater(update)
>>> kv.pull(3, out=a)
>>> print(a.asnumpy())
[[ 4.  4.  4.]
 [ 4.  4.  4.]]
>>> kv.push(3, mx.nd.ones(shape))
update on key: 3
>>> kv.pull(3, out=a)
>>> print(a.asnumpy())
[[ 6.  6.  6.]
 [ 6.  6.  6.]]
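Continuing the session above, a value can also be pulled onto several devices at once by passing a list of output arrays; this short sketch uses CPU contexts as an assumption:

>>> cpus = [mx.cpu(i) for i in range(4)]
>>> b = [mx.nd.zeros(shape, cpu) for cpu in cpus]
>>> kv.pull(3, out=b)      # the stored value is copied into every array in the list
>>> print(b[1].asnumpy())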
5. Data Loading
class mxnet.io.NDArrayIter(data, label=None, batch_size=1, shuffle=False, last_batch_handle='pad', data_name='data', label_name='softmax_label')
Iterating on either mx.nd.NDArray or numpy.ndarray.
Parameters:
- data (array or list of array or dict of string to array) – Input data
- label (array or list of array or dict of string to array, optional) – Input label
- batch_size (int) – Batch Size
- shuffle (bool, optional) – Whether to shuffle the data
- last_batch_handle (str, optional) – How to handle the last batch. Can be 'pad', 'discard' or 'roll_over'. 'roll_over' is intended for training and can cause problems if used for prediction.
- data_name (str, optional) – The data name
- label_name (str, optional) – The label name
A data iterator reads data batch by batch:
>>> data = mx.nd.ones((100, 10))
>>> nd_iter = mx.io.NDArrayIter(data, batch_size=25)
>>> for batch in nd_iter:
...     print(batch.data)
[<NDArray 25x10 @cpu(0)>]
[<NDArray 25x10 @cpu(0)>]
[<NDArray 25x10 @cpu(0)>]
[<NDArray 25x10 @cpu(0)>]
In addition, an iterator provides information about each batch, including the shapes and names:
>>> nd_iter = mx.io.NDArrayIter(data={'data': mx.nd.ones((100, 10))},
...                             label={'softmax_label': mx.nd.ones((100,))},
...                             batch_size=25)
>>> print(nd_iter.provide_data)
[DataDesc[data,(25, 10L),<type 'numpy.float32'>,NCHW]]
>>> print(nd_iter.provide_label)
[DataDesc[softmax_label,(25,),<type 'numpy.float32'>,NCHW]]
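When the number of samples does not divide evenly by batch_size, last_batch_handle controls what happens to the final batch. A quick sketch of the three options (the batch size of 30 is an illustrative choice):

>>> data = mx.nd.ones((100, 10))
>>> for handle in ['pad', 'discard', 'roll_over']:
...     nd_iter = mx.io.NDArrayIter(data, batch_size=30, last_batch_handle=handle)
...     print(handle, sum(1 for _ in nd_iter))  # 'pad' keeps a padded final batch, 'discard' drops it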
6. Optimization: initialize and update weights
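A minimal sketch of initializing and updating weights through the Module API; the initializer, optimizer, and learning rate below are illustrative choices, and mod and train_iter stand for a freshly constructed Module and a data iterator as in Section 3:

mod.bind(data_shapes=train_iter.provide_data, label_shapes=train_iter.provide_label)
mod.init_params(initializer=mx.init.Xavier(magnitude=2.0))  # initialize weights
mod.init_optimizer(optimizer='sgd', optimizer_params={'learning_rate': 0.1})
for batch in train_iter:
    mod.forward(batch)   # forward pass on one batch
    mod.backward()       # compute gradients
    mod.update()         # apply one optimizer step to the weights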