Torch for Beginners (Part 1)

Source: Internet | Editor: 程序博客网 | Time: 2024/06/01 12:43
  1. Tensor

    • A Tensor is a multi-dimensional matrix; for more than four dimensions, the sizes can be passed in a LongStorage

        -- creation of a 4D tensor of size 4x5x6x2
        z = torch.Tensor(4,5,6,2)
        -- for more dimensions (here a 6D tensor), one can do:
        s = torch.LongStorage(6)
        s[1] = 4; s[2] = 5; s[3] = 6; s[4] = 2; s[5] = 7; s[6] = 3;
        x = torch.Tensor(s)

      The number of dimensions of a Tensor can be queried with nDimension() or dim(). The size of the i-th dimension is returned by size(i) (note that i starts at 1), and size() returns a LongStorage containing all the dimensions:

        > x:nDimension()
        6
        > x:size()
         4
         5
         6
         2
         7
         3
        [torch.LongStorage of size 6]
    • The actual data is stored in a Storage
      The actual data of a Tensor is contained in a Storage, which can be accessed with storage(). The memory of a Tensor must live in this single Storage, but it need not be contiguous. The first position used in the Storage is given by storageOffset() (starting at 1). The jump needed to go from one element to the next along the i-th dimension is given by stride(i).

        > x = torch.Tensor(7,7,7)
        > -- element (3,4,5) can be accessed with
        > x[3][4][5]
        > -- or, equivalently but more slowly, through the Storage directly
        > x:storage()[x:storageOffset()+(3-1)*x:stride(1)+(4-1)*x:stride(2)+(5-1)*x:stride(3)]
        > -- both expressions return the same (here uninitialized) value

      A Tensor is therefore a particular way of viewing a Storage. The Storage only represents a chunk of memory, and the Tensor interprets that chunk as having dimensions:

        > x = torch.Tensor(4,5)
        > s = x:storage()
        > for i=1,s:size() do -- fill up the Storage
        >>   s[i] = i
        >> end
        > x -- s is interpreted by x as a 2D matrix
          1   2   3   4   5
          6   7   8   9  10
         11  12  13  14  15
         16  17  18  19  20
        [torch.DoubleTensor of dimension 4x5]

      Note also that in Torch7, elements in the same row of a matrix are contiguous in memory. More generally, elements along the last dimension of a tensor are contiguous:

        > x = torch.Tensor(4,5)
        > i = 0
        > x:apply(function()
        >>   i = i + 1
        >>   return i
        >> end)
        > x
          1   2   3   4   5
          6   7   8   9  10
         11  12  13  14  15
         16  17  18  19  20
        [torch.DoubleTensor of dimension 4x5]
        > x:stride()
         5
         1  -- elements in the last dimension are contiguous!
        [torch.LongStorage of size 2]
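The storageOffset()/stride() arithmetic above can be checked without Torch. The sketch below is a pure-Lua model written for illustration, not Torch's implementation. It assumes row-major ("contiguous") layout, which is what x:stride() reports above. The helper names contiguousStrides and storageIndex are hypothetical.

```lua
-- Pure-Lua illustration (no torch calls) of how an N-d index maps onto a
-- flat, 1-based Storage. contiguousStrides/storageIndex are hypothetical names.

-- Row-major strides: the last dimension has stride 1, each earlier dimension
-- jumps over the product of all later sizes.
local function contiguousStrides(sizes)
  local strides, step = {}, 1
  for d = #sizes, 1, -1 do
    strides[d] = step
    step = step * sizes[d]
  end
  return strides
end

-- storageOffset + sum over d of (idx[d] - 1) * stride[d]
local function storageIndex(offset, strides, idx)
  local pos = offset
  for d = 1, #idx do
    pos = pos + (idx[d] - 1) * strides[d]
  end
  return pos
end

local s777 = contiguousStrides({7, 7, 7})
print(table.concat(s777, " "))                 -- 49 7 1
print(storageIndex(1, s777, {3, 4, 5}))        -- 1 + 2*49 + 3*7 + 4*1 = 124

-- For the 4x5 example, storage slot k holds k, so x[3][2] lives in slot 12.
local s45 = contiguousStrides({4, 5})
print(table.concat(s45, " "))                  -- 5 1
print(storageIndex(1, s45, {3, 2}))            -- 12
```

The 4x5 strides come out as {5, 1}, matching x:stride() in the session above.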
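    • Tensor types: DoubleTensor and FloatTensor are the most commonly used. Writing torch.Tensor instead of a concrete type lets you write type-independent scripts and choose the tensor type at run time, e.g. torch.setdefaulttensortype('torch.FloatTensor')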
        ByteTensor   -- contains unsigned chars
        CharTensor   -- contains signed chars
        ShortTensor  -- contains shorts
        IntTensor    -- contains ints
        LongTensor   -- contains longs
        FloatTensor  -- contains floats
        DoubleTensor -- contains doubles
    • Efficient memory management
      None of the tensor operations in this class copy memory. Every such method either transforms the existing tensor or returns a new tensor that references the same Storage. Internally, this is achieved by adjusting stride() and storageOffset().

        > x = torch.Tensor(5):zero()
        > x
         0
         0
         0
         0
         0
        [torch.DoubleTensor of dimension 5]
        > x:narrow(1, 2, 3):fill(1) -- the narrowed tensor shares x's Storage
        > x
         0
         1
         1
         1
         0
        [torch.DoubleTensor of dimension 5]

      If you really need to copy a Tensor, you can use the copy() method or the convenience method clone():

        > y = torch.Tensor(x:size()):copy(x)
        > y = x:clone()
        > y
         0
         1
         1
         1
         0
        [torch.DoubleTensor of dimension 5]
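The no-copy behaviour of narrow(), and the contrast with clone(), can be modelled in a few lines of plain Lua. This is a 1-D sketch under simplifying assumptions: the Storage is a Lua table, and a "view" is just {storage, offset, length, step}. The View type and its methods are hypothetical and are not Torch's API.

```lua
-- Hypothetical 1-D "view" over a shared Lua table that plays the Storage role.
local View = {}
View.__index = View

local function newView(storage, offset, len, step)
  return setmetatable({storage = storage, offset = offset,
                       len = len, step = step}, View)
end

-- Storage position of element i: offset + (i - 1) * step, i.e. the 1-D case
-- of the storageOffset()/stride() formula in the text.
function View:pos(i) return self.offset + (i - 1) * self.step end
function View:get(i) return self.storage[self:pos(i)] end
function View:put(i, v) self.storage[self:pos(i)] = v end

-- Like narrow(1, index, size): new offset and length, SAME storage, no copy.
function View:narrow(index, size)
  return newView(self.storage, self:pos(index), size, self.step)
end

function View:fill(v)
  for i = 1, self.len do self:put(i, v) end
  return self
end

-- Like clone(): copy the elements into a fresh storage.
function View:clone()
  local s = {}
  for i = 1, self.len do s[i] = self:get(i) end
  return newView(s, 1, self.len, 1)
end

local x = newView({0, 0, 0, 0, 0}, 1, 5, 1)   -- like torch.Tensor(5):zero()
x:narrow(2, 3):fill(1)                          -- writes through to x's storage

local out = {}
for i = 1, x.len do out[i] = x:get(i) end
print(table.concat(out, " "))                   -- 0 1 1 1 0

local y = x:clone()
y:put(1, 9)                                     -- changes only the copy
print(x:get(1), y:get(1))                       -- 0   9
```

The design choice mirrors the text. A view only records where it starts and how far to step, so narrowing costs O(1). Only clone() allocates new memory.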
    • Creating Tensors
      • torch.Tensor()
      • torch.Tensor(tensor)
      • torch.Tensor(sz1 [,sz2 [,sz3 [,sz4]]])
      • torch.Tensor(sizes, [strides])
      • torch.Tensor(storage, [storageOffset, sizes, [strides]])
      • torch.Tensor(storage, [storageOffset, sz1 [, st1 … [, sz4 [, st4]]]])
      • torch.Tensor(table)
    • Methods

      • Cloning
        • clone()
        • contiguous()
        • type(type)
        • typeAs(tensor)
        • isTensor(object)
        • byte(), char(), short(), int(), long(), float(), double()
      • Querying the size and structure
        • nDimension()
        • dim()
        • size(dim)
        • size()
        • #self
        • stride(dim)
        • stride()
        • storage()
        • isContiguous()
        • isSize(storage)
        • isSameSizeAs(tensor)
        • nElement()
        • storageOffset()
      • Accessing elements
        > x = torch.Tensor(3,3)
        > i = 0; x:apply(function() i = i + 1; return i end)
        > x
         1  2  3
         4  5  6
         7  8  9
        [torch.DoubleTensor of dimension 3x3]
        > x[2] -- returns row 2
         4
         5
         6
        [torch.DoubleTensor of dimension 3]
        > x[2][3] -- returns row 2, column 3
        6
        > x[{2,3}] -- another way to return row 2, column 3
        6
        > x[torch.LongStorage{2,3}] -- yet another way to return row 2, column 3
        6
        > x[torch.le(x,3)] -- torch.le returns a ByteTensor that acts as a mask
         1
         2
         3
        [torch.DoubleTensor of dimension 3]
      • Referencing a tensor to an existing tensor or chunk of memory
        • set(tensor)
        • isSetTo(tensor)
        • set(storage, [storageOffset, sizes, [strides]])
        • set(storage, [storageOffset, sz1 [, st1 … [, sz4 [, st4]]]])
      • Copying and initializing
        • copy(tensor)
        • fill(value)
        • zero()
      • Resizing
        • resizeAs(tensor)
        • resize(sizes)
        • resize(sz1 [,sz2 [,sz3 [,sz4]]])
      • Extracting sub-tensors
        • narrow(dim, index, size)
        • sub(dim1s, dim1e … [, dim4s [, dim4e]])
        • select(dim, index)
        • [{ dim1,dim2,… }] or [{ {dim1s,dim1e}, {dim2s,dim2e} }]
        • index(dim, index)
        • indexCopy(dim, index, tensor)
        • indexAdd(dim, index, tensor)
        • indexFill(dim, index, val)
        • gather(dim, index)
        • scatter(dim, index, src|val)
        • maskedSelect(mask)
        • maskedCopy(mask, tensor)
        • maskedFill(mask, val)
      • Search
        • nonzero(tensor)
      • Expanding/Replicating/Squeezing Tensors
        • expand([result,] sizes)
        • expandAs([result,] tensor)
        • repeatTensor([result,] sizes)
        • squeeze([dim])
      • Manipulating the tensor view
        • view([result,] tensor, sizes)
        • viewAs([result,] tensor, template)
        • transpose(dim1, dim2)
        • t()
        • permute(dim1, dim2, …, dimn)
        • unfold(dim, size, step)
      • Applying a function to a tensor
        • apply(function)
        • map(tensor, function(xs, xt))
        • map2(tensor1, tensor2, function(x, xt1, xt2))
      • Dividing a tensor into a table of tensors
        • split([result,] tensor, size, [dim])
        • chunk([result,] tensor, n, [dim])
      • LuaJIT FFI access
        • data(tensor, [asnumber])
        • cdata(tensor, [asnumber])
      • Reference counting
        • retain()
        • free()
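The mask indexing x[torch.le(x,3)] from the element-access example works in two steps. First, a 0/1 mask with the same shape as x is built. Then the elements where the mask is 1 are collected in row-major order into a 1-D result. The pure-Lua sketch below mimics those two steps on a table of rows. The helpers le and maskedSelect are hypothetical stand-ins for torch.le and maskedSelect(mask), and no torch calls are made.

```lua
-- 3x3 matrix modelled as a Lua table of rows (same values as the example).
local x = {{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}

-- Stand-in for torch.le(x, value): 1 where x <= value, 0 elsewhere.
local function le(m, value)
  local mask = {}
  for r = 1, #m do
    mask[r] = {}
    for c = 1, #m[r] do
      mask[r][c] = (m[r][c] <= value) and 1 or 0
    end
  end
  return mask
end

-- Stand-in for x[mask]: a flat list of the selected values, row-major order.
local function maskedSelect(m, mask)
  local out = {}
  for r = 1, #m do
    for c = 1, #m[r] do
      if mask[r][c] == 1 then out[#out + 1] = m[r][c] end
    end
  end
  return out
end

print(table.concat(maskedSelect(x, le(x, 3)), " "))  -- 1 2 3
print(x[2][3])                                       -- 6, as with x[2][3] above
```

Because the result is always flattened to one dimension, its length equals the number of ones in the mask, not the shape of x.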
  2. Mathematical operations
    • Construction or extraction functions
      • torch.cat( [res,] x_1, x_2, [dimension] )
      • torch.cat( [res,] {x_1, x_2, …}, [dimension] )
      • torch.diag([res,] x [,k])
      • torch.eye([res,] n [,m])
      • torch.histc([res,] x [,nbins, min_value, max_value])
      • torch.linspace([res,] x1, x2 [,n])
      • torch.logspace([res,] x1, x2 [,n])
      • torch.multinomial([res,] p, n [,replacement])
      • torch.ones([res,] m [,n…])
      • torch.rand([res,] [gen,] m [,n…])
      • torch.randn([res,] [gen,] m [,n…])
      • torch.range([res,] x, y [,step])
      • torch.randperm([res,] [gen,] n)
      • torch.reshape([res,] x, m [,n…])
      • torch.tril([res,] x [,k])
      • torch.triu([res,] x [,k])
      • torch.zeros([res,] x)
    • Element-wise Mathematical Operations
      • torch.abs([res,] x)
      • torch.sign([res,] x)
      • torch.acos([res,] x)
      • torch.asin([res,] x)
      • torch.atan([res,] x)
      • torch.ceil([res,] x)
      • torch.cos([res,] x)
      • torch.cosh([res,] x)
      • torch.exp([res,] x)
      • torch.floor([res,] x)
      • torch.log([res,] x)
      • torch.log1p([res,] x)
      • x:neg()
      • x:cinv()
      • torch.pow([res,] x, n)
      • torch.round([res,] x)
      • torch.sin([res,] x)
      • torch.sinh([res,] x)
      • torch.sqrt([res,] x)
      • torch.rsqrt([res,] x)
      • torch.tan([res,] x)
      • torch.tanh([res,] x)
      • torch.sigmoid([res,] x)
      • torch.trunc([res,] x)
      • torch.frac([res,] x)
    • Basic operations
      • equal([tensor1,] tensor2)
      • torch.add([res,] tensor, value)
      • torch.add([res,] tensor1, tensor2)
      • torch.add([res,] tensor1, value, tensor2)
      • tensor:csub(value)
      • tensor1:csub(tensor2)
      • torch.mul([res,] tensor1, value)
      • torch.clamp([res,] tensor, min_value, max_value)
      • torch.cmul([res,] tensor1, tensor2)
      • torch.cpow([res,] tensor1, tensor2)
      • torch.addcmul([res,] x [,value], tensor1, tensor2)
      • torch.div([res,] tensor, value)
      • torch.cdiv([res,] tensor1, tensor2)
      • torch.addcdiv([res,] x [,value], tensor1, tensor2)
      • torch.fmod([res,] tensor, value)
      • torch.remainder([res,] tensor, value)
      • torch.mod([res,] tensor, value)
      • torch.cfmod([res,] tensor1, tensor2)
      • torch.cremainder([res,] tensor1, tensor2)
      • torch.cmod([res,] tensor1, tensor2)
      • torch.dot(tensor1, tensor2)
      • torch.addmv([res,] [beta,] [v1,] vec1, [v2,] mat, vec2)
      • torch.addr([res,] [v1,] mat, [v2,] vec1, vec2)
      • torch.addmm([res,] [beta,] [v1,] M, [v2,] mat1, mat2)
      • torch.addbmm([res,] [v1,] M, [v2,] batch1, batch2)
      • torch.baddbmm([res,] [v1,] M, [v2,] batch1, batch2)
      • torch.mv([res,] mat, vec)
      • torch.mm([res,] mat1, mat2)
      • torch.bmm([res,] batch1, batch2)
      • torch.ger([res,] vec1, vec2)
      • torch.lerp([res,] a, b, weight)
    • Overloaded operators
      • Addition and subtraction
      • Negation
      • Multiplication
      • Division and Modulo (remainder)
    • Column or row-wise operations (dimension-wise operations)
      • torch.cross([res,] a, b [,n])
      • torch.cumprod([res,] x [,dim])
      • torch.cumsum([res,] x [,dim])
      • torch.max([resval, resind,] x [,dim])
      • torch.mean([res,] x [,dim])
      • torch.min([resval, resind,] x [,dim])
      • torch.cmax([res,] tensor1, tensor2)
      • torch.cmax([res,] tensor, value)
      • torch.cmin([res,] tensor1, tensor2)
      • torch.cmin([res,] tensor, value)
      • torch.median([resval, resind,] x [,dim])
      • torch.mode([resval, resind,] x [,dim])
      • torch.kthvalue([resval, resind,] x, k [,dim])
      • torch.prod([res,] x [,n])
      • torch.sort([resval, resind,] x [,d] [,flag])
      • torch.topk([resval, resind,] x, k [,dim] [,dir] [,sort])
      • torch.std([res,] x [,dim] [,flag])
      • torch.sum([res,] x)
      • torch.var([res,] x [,dim] [,flag])
    • Matrix-wide operations
      • torch.norm(x [,p] [,dim])
      • torch.renorm([res], x, p, dim, maxnorm)
      • torch.dist(x, y)
      • torch.numel(x)
      • torch.trace(x)
    • Convolution Operations
      • torch.conv2([res,] x, k [, 'F' or 'V'])
      • torch.xcorr2([res,] x, k [, 'F' or 'V'])
      • torch.conv3([res,] x, k [, 'F' or 'V'])
      • torch.xcorr3([res,] x, k [, 'F' or 'V'])
    • Eigenvalues, SVD, Linear System Solution
      • torch.gesv([resb, resa,] B, A)
      • torch.trtrs([resb, resa,] b, a [, 'U' or 'L'] [, 'N' or 'T'] [, 'N' or 'U'])
      • torch.potrf([res,] A [, 'U' or 'L'])
      • torch.pstrf([res, piv,] A [, 'U' or 'L'])
      • torch.potrs([res,] B, chol [, 'U' or 'L'])
      • torch.potri([res,] chol [, 'U' or 'L'])
      • torch.gels([resb, resa,] b, a)
      • torch.symeig([rese, resv,] a [, 'N' or 'V'] [, 'U' or 'L'])
      • torch.eig([rese, resv,] a [, 'N' or 'V'])
      • torch.svd([resu, ress, resv,] a [, 'S' or 'A'])
      • torch.inverse([res,] x)
      • torch.qr([q, r], x)
      • torch.geqrf([m, tau], a)
      • torch.orgqr([q], m, tau)
      • torch.ormqr([res], m, tau, mat [, 'L' or 'R'] [, 'N' or 'T'])
    • Logical Operations on Tensors
      • torch.lt(a, b)
      • torch.le(a, b)
      • torch.gt(a, b)
      • torch.ge(a, b)
      • torch.eq(a, b)
      • torch.ne(a, b)
      • torch.all(a)
      • torch.any(a)
  3. Storage interface
    • Constructors and Access Methods
      • torch.TYPEStorage([size [, ptr]])
      • torch.TYPEStorage(table)
      • torch.TYPEStorage(storage [, offset [, size]])
      • torch.TYPEStorage(filename [, shared [, size [, sharedMem]]])
      • self[index]
      • copy(storage)
      • fill(value)
      • resize(size)
      • size()
      • string(str)
      • string()
    • Reference counting methods
      • retain()
      • free()
  4. File
    • Read methods
    • Write methods
    • Serialization methods
      • readObject()
      • writeObject(object)
      • readString(format)
      • writeString(str)
    • General Access and Control Methods
      • ascii()
      • autoSpacing()
      • binary()
      • clearError()
      • close()
      • noAutoSpacing()
      • synchronize()
      • pedantic()
      • position()
      • quiet()
      • seek(position)
      • seekEnd()
    • File state query
      • hasError()
      • isQuiet()
      • isReadable()
      • isWritable()
      • isAutoSpacing()
      • referenced(ref)
      • isReferenced()
  5. Tester

    • Tester()
      • torch.Tester()
      • add(f, 'name')
      • run(testNames)
      • disable(testNames)
      • assert(condition [, message])
      • assertGeneralEq(got, expected [, tolerance] [, message])
      • eq(got, expected [, tolerance] [, message])
      • assertGeneralNe(got, unexpected [, tolerance] [, message])
      • ne(got, unexpected [, tolerance] [, message])
      • assertlt(a, b [, message])
      • assertgt(a, b [, message])
      • assertle(a, b [, message])
      • assertge(a, b [, message])
      • asserteq(a, b [, message])
      • assertne(a, b [, message])
      • assertalmosteq(a, b [, tolerance] [, message])
      • assertTensorEq(ta, tb [, tolerance] [, message])
      • assertTensorNe(ta, tb [, tolerance] [, message])
      • assertTableEq(ta, tb [, tolerance] [, message])
      • assertTableNe(ta, tb [, tolerance] [, message])
      • assertError(f [, message])
      • assertNoError(f [, message])
      • assertErrorMsg(f, errmsg [, message])
      • assertErrorPattern(f, errPattern [, message])
      • assertErrorObj(f, errcomp [, message])
      • setEarlyAbort(earlyAbort)
      • setRethrowErrors(rethrowErrors)
      • setSummaryOnly(summaryOnly)
    • TestSuite

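      A torch.TestSuite() is used like a plain table of test functions, except that defining a second test with an existing name raises an error:

        > test = torch.TestSuite()
        >
        > function test.myTest()
        >    -- ...
        > end
        >
        > -- ...
        >
        > function test.myTest()
        >    -- ...
        > end
        torch/TestSuite.lua:16: Test myTest is already defined.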
    • Example

        local mytest = torch.TestSuite()
        local tester = torch.Tester()

        function mytest.testA()
           local a = torch.Tensor{1, 2, 3}
           local b = torch.Tensor{1, 2, 4}
           tester:eq(a, b, "a and b should be equal")
        end

        function mytest.testB()
           local a = {2, torch.Tensor{1, 2, 2}}
           local b = {2, torch.Tensor{1, 2, 2.001}}
           tester:eq(a, b, 0.01, "a and b should be approximately equal")
        end

        function mytest.testC()
           local function myfunc()
              return "hello " .. world
           end
           tester:assertNoError(myfunc, "myfunc shouldn't give an error")
        end

        tester:add(mytest)
        tester:run()
      Output:

        Running 3 tests
        1/3 testB ............................................................... [PASS]
        2/3 testA ............................................................... [FAIL]
        3/3 testC ............................................................... [FAIL]
        Completed 3 asserts in 3 tests with 2 failures and 0 errors
        --------------------------------------------------------------------------------
        testA
        a and b should be equal
        TensorEQ(==) violation: max diff=1, tolerance=0
        stack traceback:
                ./test.lua:8: in function <./test.lua:5>

        --------------------------------------------------------------------------------
        testC
        myfunc shouldn't give an error
        ERROR violation: err=./test.lua:19: attempt to concatenate global 'world' (a nil value)
        stack traceback:
                ./test.lua:21: in function <./test.lua:17>

        --------------------------------------------------------------------------------
        torch/torch/Tester.lua:383: An error was found while running tests!
        stack traceback:
                [C]: in function 'assert'
                torch/torch/Tester.lua:383: in function 'run'
                ./test.lua:25: in main chunk
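The report above suggests how the tolerance comparison works. testA fails with "max diff=1, tolerance=0", while testB, whose largest difference is about 0.001, passes with tolerance 0.01. The sketch below reproduces that behaviour in plain Lua. It assumes the check is "largest absolute element-wise difference <= tolerance", which is an assumption and not taken from Tester's source. The helpers maxAbsDiff and approxEqual are hypothetical and not part of torch.Tester.

```lua
-- Largest absolute element-wise difference between two equal-length arrays.
local function maxAbsDiff(a, b)
  assert(#a == #b, "size mismatch")
  local maxDiff = 0
  for i = 1, #a do
    maxDiff = math.max(maxDiff, math.abs(a[i] - b[i]))
  end
  return maxDiff
end

-- Assumed rule: pass when the max difference does not exceed the tolerance
-- (tolerance defaults to 0, i.e. exact equality).
local function approxEqual(a, b, tolerance)
  local d = maxAbsDiff(a, b)
  return d <= (tolerance or 0), d
end

print(approxEqual({1, 2, 3}, {1, 2, 4}, 0))         -- false  1   (like testA)
print(approxEqual({1, 2, 2}, {1, 2, 2.001}, 0.01))  -- true, max diff about 0.001 (like testB)
```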
  6. cmdline

    • addTime([name] [,format])
    • log(filename, parameter_table)
    • option(name, default, help)
    • parse(arg)
    • silent()
    • string(prefix, params, ignore)
    • text(string)
    • Example
        cmd = torch.CmdLine()
        cmd:text()
        cmd:text()
        cmd:text('Training a simple network')
        cmd:text()
        cmd:text('Options')
        cmd:option('-seed',123,'initial random seed')
        cmd:option('-booloption',false,'boolean option')
        cmd:option('-stroption','mystring','string option')
        cmd:text()

        -- parse input params
        params = cmd:parse(arg)

        params.rundir = cmd:string('experiment', params, {dir=true})
        paths.mkdir(params.rundir)

        -- create log file
        cmd:log(params.rundir .. '/log', params)

    Running th myscript.lua produces the following output:

        [program started on Tue Jan 10 15:33:49 2012]
        [command line arguments]
        booloption  false
        seed    123
        rundir  experiment
        stroption   mystring
        [----------------------]
        booloption  false
        seed    123
        rundir  experiment
        stroption   mystring

    Running th myscript.lua -seed 456 -stroption mycustomstring produces the following output:

        [program started on Tue Jan 10 15:36:55 2012]
        [command line arguments]
        booloption  false
        seed    456
        rundir  experiment,seed=456,stroption=mycustomstring
        stroption   mycustomstring
        [----------------------]
        booloption  false
        seed    456
        rundir  experiment,seed=456,stroption=mycustomstring
        stroption   mycustomstring
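The two runs show how cmd:string('experiment', params, {dir=true}) names the run directory. With default options it is just "experiment". With -seed 456 -stroption mycustomstring it becomes "experiment,seed=456,stroption=mycustomstring". The pure-Lua sketch below reproduces that naming rule as observed, and it is not CmdLine's source. It assumes that:

- only options whose value differs from the default are kept;
- keys listed in `ignore` are skipped;
- the resulting "name=value" pairs are sorted and appended to the prefix with commas.

The helper runName and its `defaults` argument are hypothetical.

```lua
-- Hypothetical re-creation of the observed run-directory naming rule.
local function runName(prefix, defaults, params, ignore)
  local parts = {}
  for name, value in pairs(params) do
    if not ignore[name] and defaults[name] ~= value then
      parts[#parts + 1] = name .. "=" .. tostring(value)
    end
  end
  table.sort(parts)            -- pairs() order is unspecified; sort for stability
  table.insert(parts, 1, prefix)
  return table.concat(parts, ",")
end

local defaults = {seed = 123, booloption = false, stroption = "mystring"}

-- th myscript.lua
print(runName("experiment", defaults,
  {seed = 123, booloption = false, stroption = "mystring"}, {dir = true}))
-- experiment

-- th myscript.lua -seed 456 -stroption mycustomstring
print(runName("experiment", defaults,
  {seed = 456, booloption = false, stroption = "mycustomstring"}, {dir = true}))
-- experiment,seed=456,stroption=mycustomstring
```

Encoding only the non-default options keeps directory names short. It also makes runs with different settings land in different directories.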
  7. Random

    • Generator handling
    • Seed Handling
  8. Utility