caffe源码分析--Blob类代码研究

来源:互联网 发布:阅读题软件 编辑:程序博客网 时间:2024/04/30 16:20

作者:linger

转自:http://blog.csdn.net/lingerlanlan/article/details/24379689



数据成员

// The blob's values; a smart pointer to a SyncedMemory.
shared_ptr<SyncedMemory> data_;
// The "difference" buffer, used to update data_.
shared_ptr<SyncedMemory> diff_;
// The four dimensions of the blob.
int num_;
int channels_;
int height_;
int width_;
// Total element count; Reshape sets it to num_ * channels_ * height_ * width_.
int count_;



构造函数

// Default constructor: zeroes every dimension and count_, and leaves data_
// and diff_ as empty pointers. The initializers are written in the order the
// members are declared (data_, diff_, then the integer fields), because that
// is the order C++ runs them in regardless of how the list is written.
Blob()
    : data_(), diff_(), num_(0), channels_(0), height_(0), width_(0),
      count_(0) {}

功能:简单的初始化


explicitBlob(constintnum,constintchannels,constintheight,constintwidth);

功能:调用Reshape函数,初始化数据成员

// Sized constructor: all member initialization is delegated to Reshape,
// which assigns every dimension, count_, data_ and diff_.
template <typename Dtype>
Blob<Dtype>::Blob(const int num, const int channels, const int height,
    const int width) {
  Reshape(num, channels, height, width);
}


析构函数

// Virtual destructor with an empty body; the class does no manual cleanup
// here.
virtual ~Blob() {}

功能:函数体为空,什么都不用做;data_和diff_是智能指针,析构时会自动释放对SyncedMemory对象的引用。






// (Re)initializes the blob's dimensions and storage pointers. The parameter
// names are listed in the same order the definition uses:
// num, channels, height, width.
void Reshape(const int num, const int channels, const int height,
    const int width);

功能:初始化数据成员,智能指针指向SyncedMemory对象。此时SyncedMemory对象其实并没有为自己的“数据”申请内存,只是记录了自己“数据”的大小(size)。

template<typenameDtype>

voidBlob<Dtype>::Reshape(constintnum,constintchannels,constintheight,

constintwidth) {

CHECK_GE(num,0);

CHECK_GE(channels,0);

CHECK_GE(height,0);

CHECK_GE(width,0);

num_= num;

channels_= channels;

height_= height;

width_= width;

count_=num_*channels_*height_*width_;

if(count_){

data_.reset(newSyncedMemory(count_*sizeof(Dtype)));

diff_.reset(newSyncedMemory(count_*sizeof(Dtype)));

}else{

data_.reset(reinterpret_cast<SyncedMemory*>(NULL));

diff_.reset(reinterpret_cast<SyncedMemory*>(NULL));

}

}



成员访问函数

功能:就是返回一些成员变量

inlineintnum()const{returnnum_;}

inlineintchannels()const{returnchannels_;}

inlineintheight()const{returnheight_;}

inlineintwidth()const{returnwidth_;}

inlineintcount()const{returncount_;}

inlineintoffset(constintn,constintc = 0, constinth = 0,constintw = 0) const{

return((n * channels_+ c) *height_+ h) *width_+ w;

//计算偏移量,因为数据在内存是一维数组形式的,所以需要计算偏移量来访问

}


“数据”指针返回函数

功能:其实这些函数就是调用SyncedMemory的函数,来返回数据的指针

// Read-only pointers to the data and diff buffers, for CPU and GPU access.
const Dtype* cpu_data() const;
const Dtype* gpu_data() const;
const Dtype* cpu_diff() const;
const Dtype* gpu_diff() const;

// Writable pointers to the same buffers.
Dtype* mutable_cpu_data();
Dtype* mutable_gpu_data();
Dtype* mutable_cpu_diff();
Dtype* mutable_gpu_diff();


inlineDtypedata_at(constintn,constintc,constinth,

constintw)const{

//cpu访问数据data

return*(cpu_data()+ offset(n, c, h, w));

}


inlineDtypediff_at(constintn,constintc,constinth,

constintw)const{

//cpu访问数据diff

return*(cpu_diff() + offset(n, c, h, w));

}



函数 void Update()

功能:更新data_的数据,就是减去diff_的数据。



template<typenameDtype>

voidBlob<Dtype>::Update(){

//We will perform update based on where the data is located.

switch(data_->head()){

caseSyncedMemory::HEAD_AT_CPU:

//perform computation on CPU

caffe_axpy<Dtype>(count_,Dtype(-1),

reinterpret_cast<constDtype*>(diff_->cpu_data()),

reinterpret_cast<Dtype*>(data_->mutable_cpu_data()));

//math_functions.cpp可以找到该函数的实现,其实这函数也是封装了mkl的函数。这里调用是为了实现了两个向量的减法。

break;

caseSyncedMemory::HEAD_AT_GPU:

caseSyncedMemory::SYNCED:

//perform computation on GPU

caffe_gpu_axpy<Dtype>(count_,Dtype(-1),

reinterpret_cast<constDtype*>(diff_->gpu_data()),

reinterpret_cast<Dtype*>(data_->mutable_gpu_data()));

//math_functions.cpp可以找到该函数的实现,其实这函数也是封装了cublas的函数。这里调用是为了实现了两个向量的减法。

break;

default:

LOG(FATAL)<<"Syncedmemnot initialized.";

}

}



函数 void CopyFrom(const Blob<Dtype>& source, bool copy_diff = false, bool reshape = false);

功能:从source拷贝数据。copy_diff作为标志来区分是拷贝data还是拷贝diff;当两个blob的形状不一致时,reshape为true则先调用Reshape调整为source的形状,否则报错退出。

template<typenameDtype>

voidBlob<Dtype>::CopyFrom(constBlob&source,boolcopy_diff,boolreshape) {

if(num_!= source.num() || channels_!= source.channels() ||

height_!= source.height() || width_!= source.width()) {

if(reshape) {

Reshape(source.num(),source.channels(), source.height(), source.width());

}else{

LOG(FATAL)<<"Tryingto copy blobs of different sizes.";

}

}

switch(Caffe::mode()){

caseCaffe::GPU:

if(copy_diff){

CUDA_CHECK(cudaMemcpy(diff_->mutable_gpu_data(),source.gpu_diff(),

sizeof(Dtype)*count_,cudaMemcpyDeviceToDevice));

}else{

CUDA_CHECK(cudaMemcpy(data_->mutable_gpu_data(),source.gpu_data(),

sizeof(Dtype)*count_,cudaMemcpyDeviceToDevice));

}

break;

caseCaffe::CPU:

if(copy_diff){

memcpy(diff_->mutable_cpu_data(),source.cpu_diff(),

sizeof(Dtype)*count_);

}else{

memcpy(data_->mutable_cpu_data(),source.cpu_data(),

sizeof(Dtype)*count_);

}

break;

default:

LOG(FATAL)<<"Unknowncaffemode.";

}

}




函数 void FromProto(const BlobProto& proto);

功能:从proto读数据进来,其实就是反序列化

template<typenameDtype>

voidBlob<Dtype>::FromProto(constBlobProto&proto){

Reshape(proto.num(),proto.channels(),proto.height(),proto.width());

//copy data

Dtype*data_vec = mutable_cpu_data();

for(inti = 0; i < count_;++i) {

data_vec[i]=proto.data(i);

}

if(proto.diff_size()> 0) {

Dtype*diff_vec = mutable_cpu_diff();

for(inti = 0; i < count_;++i) {

diff_vec[i]=proto.diff(i);

}

}

}



函数 void ToProto(BlobProto* proto, bool write_diff = false) const;

功能:序列化到proto保存

template<typenameDtype>

voidBlob<Dtype>::ToProto(BlobProto*proto,boolwrite_diff)const{

proto->set_num(num_);

proto->set_channels(channels_);

proto->set_height(height_);

proto->set_width(width_);

proto->clear_data();

proto->clear_diff();

constDtype*data_vec = cpu_data();

for(inti = 0; i < count_;++i) {

proto->add_data(data_vec[i]);

}

if(write_diff) {

constDtype*diff_vec = cpu_diff();

for(inti = 0; i < count_;++i) {

proto->add_diff(diff_vec[i]);

}

}

}

0 0
原创粉丝点击