XGBoost解析系列-数据加载

来源:互联网 发布:中国电信网络在线测速 编辑:程序博客网 时间:2024/06/05 07:58

  • 前言
  • XGBoost数据加载
    • 1 DMatrixLoad主流程
    • 2 解析器parser构建过程
    • 3 DMatrix对象构建过程


0.前言

  本文主要介绍XGBoost中数据加载过程,主要是DMatrix::Load内容。

1. XGBoost数据加载

1.1 DMatrix::Load主流程

  数据集加载语句为:

std::shared_ptr<DMatrix> dtrain(DMatrix::Load(param.train_path, param.silent != 0, param.dsplit == 2));# 函数原型为:DMatrix* DMatrix::Load(const std::string& uri, bool silent, bool load_row_split, const std::string& file_format) 

  可见,DMatrix::Load返回为DMatrix对象指针,参数传入为:1)文件URI地址,2)silent开关,为true打印统计信息,3)load_row_split为分布式开关,为true,则对数据进行分片shard。4)file_format为数据解析格式,默认参数为"auto",自动解析文件数据格式。基于mushroom.conf配置, 提供训练数据地址:xgboost目录下demo/data/agaricus.txt.train,数据为libsvm格式。DMatrix::Load主流程如下:
  1. uri带有#符号,解析出cache_file文件,考虑分布式模式的情况;
  2. 分布式模式,获取当前主机排序partid以及主机总数npart,单机下partid=1;npart=1
  3. 满足file_format == "auto" && npart == 1,检测文件是否为二进制数据文件,检查开头魔方是否为SimpleCSRSource::kMagic,若是则LoadBinary直接初始化SimpleCSRSource数据源对象source
  4. 构建解析器parser,基于工厂设计模式,基类dmlc::Parser<uint32_t>调用静态方法Create(), 尽管格式解析为”auto”,目前使用libsvm格式解析。
  5. 解析器parser构建返回DMatrix对象:dmat = DMatrix::Create(parser.get(), cache_file),核心过程,后续详解。
  6. 根据参数slient打印相关统计信息,尝试读取.group后缀的文件;.base_margin后缀的boost初始设定值;.weight样本的权重,用于代价敏感学习,不存在则跳过。

1.2 解析器parser构建过程

  该小节先对步骤4进行详解,Parser<uint32_t>继承于DataIter<RowBlock<uint32_t> >,实际上内部数据以CSR格式。基于静态方法Parser<uint32_t>::Create()构建实例,该方法调用普通方法CreateParser_方法。由于配置最后解析为libsvm格式,使用ParserFactoryReg工厂来找到对应的LibSVMParser解析器对象。具体代码参考如下:

// 通过ptype=libsvm找到已注册的解析器工厂方法,Get()->Find()后面会看到const ParserFactoryReg<IndexType>* e =      Registry<ParserFactoryReg<IndexType> >::Get()->Find(ptype);// 通过工厂方法生成解析器对象,最终调用的是CreateLibSVMParser<uint32_t>函数return (*e->body)(spec.uri, spec.args, part_index, num_parts);

  理解上述的代码需要梳理以下的类定义与宏定义过程:

ParserFactoryReg解析器工厂类定义

// 工厂模式,通过宏定义来注册组件,继承FunctionRegEntryBase// 第1个类参数必须是本身类型,第2个类参数是工厂方法。// ParserFactoryReg会绑定工厂方法类型,通过工厂方法来组件对象,即Parser<IndexType>::Factor为函数类型// 后期会调用DMLC_REGISTRY_ENABLE实例生成静态单例工厂对象。宏定义实现可扩展的工厂模式+对象单例非常精彩template<typename IndexType>struct ParserFactoryReg : public FunctionRegEntryBase<ParserFactoryReg<IndexType>, typename Parser<IndexType>::Factory> {};

FunctionRegEntryBase工厂基类模板

// FunctionRegEntryBase方法注册类模板,需要注册项类型,方法类型template<typename EntryType, typename FunctionType>class FunctionRegEntryBase { public:  std::string name;                         // 注册项名字  std::string description;                  // 注册项描述  std::vector<ParamFieldInfo> arguments;    // factory function调用参数,ParamFieldInfo结构体描述参数  FunctionType body;                        // 函数方法体,body函数指针  std::string return_type;                  // 函数返回类型  EntryType set_body();                     // 重要方法,设置工厂方法,并返回*(static_cast<EntryType*>(this))  EntryType *Find(const std::string &name); // 根据组件注册名字,找到注册项}// 注册get()方法模板特化,使用DMLC_REGISTRY_ENABLE(ParserFactoryReg<uint32_t>); 会构建Registry<ParserFactoryReg<uint32_t> >解析器工厂对象。// Registry是注册器类模板,所有注册器类都基于该模板,静态Get方法实现单例模式#define DMLC_REGISTRY_ENABLE(EntryType)                                 \  template<>                                                            \  Registry<EntryType > *Registry<EntryType >::Get() {                   \    static Registry<EntryType > inst;                                   \    return &inst;                                                       \  }                                                                     \

Registry模板类定义

template<typename EntryType>class Registry { public:  static Registry *Get();                       // 静态Get()生成注册器对象,单例模式 private:  std::vector<EntryType*> entry_list_;          // 注册项列表  std::vector<const EntryType*> const_list_;    // 注册项列表  std::map<std::string, EntryType*> fmap_;      // 注册项map索引  EntryType &__REGISTER__(const std::string& name);  // 注册方法,返回空白注册项}

DMLC注册宏定义

// 通过Get()得到工厂对象,调用__REGISTER__()获取空白注册项,需要被上层调用构建注册项#define DMLC_REGISTRY_REGISTER(EntryType, EntryTypeName, Name)          \static DMLC_ATTRIBUTE_UNUSED EntryType & __make_ ## EntryTypeName ## _ ## Name ## __ = \  ::dmlc::Registry<EntryType>::Get()->__REGISTER__(#Name)           \// DMLC数据解析器工厂方法注册,会调用上面DMLC_REGISTRY_REGISTER过程,得到注册项后进行设置工厂方法#define DMLC_REGISTER_DATA_PARSER(IndexType, TypeName, FactoryFunction) \DMLC_REGISTRY_REGISTER(::dmlc::ParserFactoryReg<IndexType>,           \                     ParserFactoryReg ## _ ## IndexType, TypeName)  \.set_body(FactoryFunction)

libsvm解析器组件注册例子

// ./dmlc-core/src/data.cc中libsvm解析器组件注册:特征id为uint32_t类型,数据格式libsvm,工厂方法CreateLibSVMParser<uint32_t>。DMLC_REGISTER_DATA_PARSER(uint32_t, libsvm, data::CreateLibSVMParser<uint32_t>)// 将宏定义展开,等价于全局定义变量__make_ParserFactoryReg_uint32_t_libsvm__static __attribute__((unused)) ::dmlc::ParserFactoryReg<uint32_t> & __make_ParserFactoryReg_uint32_t_libsvm__ = \::dmlc::Registry<EntryType>::Get()->__REGISTER__("libsvm").set_body(data::CreateLibSVMParser<uint32_t>)

libsvm解析器工厂方法CreateLibSVMParser详解

template<typename IndexType>Parser<IndexType> *CreateLibSVMParser(const std::string& path, const std::map<std::string, std::string>& args, unsigned part_index, unsigned num_parts) {  InputSplit* source = InputSplit::Create(path.c_str(), part_index, num_parts, "text");  // libsvm使用LibSVMParser解析  ParserImpl<IndexType> *parser = new LibSVMParser<IndexType>(source, 2); #if DMLC_ENABLE_STD_THREAD  // 如果打开,则使用ThreadedParser进行包装,具备多线程加载功能  parser = new ThreadedParser<IndexType>(parser);#endif  return parser;}

  CreateLibSVMParser()生成LibSVMParser前会构建数据分片InputSplit, 又是基于工厂模式,基类InputSplit调用静态方法Create()实例化,由于使用”text”数据类型,调用LineSplitter();生成分片处理实例对象。

  LineSplitter继承于InputSplitBase,初始化会调用InputSplitBase中的InitResetPartition方法,主要完成对数据分片的功能:
  1. Init获取所有文件以及大小,生成对应数组变量file_offset_file_offset_[i]表示文件i开始数据的字节偏移量。
  2. ResetPartition根据当前主机的rank值与所有主机数,由于数据分片采用均分方式,获取当前主机分片下的数据起始位置offset_begin_、offset_end_,通过std::upper_bound二分查找file_offset_定位跨度的文件列表索引访问file_ptr_、file_ptr_end_,期间会有数据align对齐逻辑,并通过BeforeFirst初始化文件读取状态和初始化内部变量,把文件指针挪到开始位置。

1.3 DMatrix对象构建过程

  生成parser之后需要调用解析器不断获取数据生成DMatrix内存数据对象,对应主流程步骤5,具体过程如下:

  1. 初始化空数据源SimpleCSRSource对象source,继承于类DataSource,再上层类为dmlc::DataIter<RowBatch>。主要初始化成员row_ptr_指针。SimpleCSRSource成员如下:

class SimpleCSRSource : public DataSource { public:  // MetaInfo info; // 继承于DataSource,单机全局元信息  std::vector<size_t> row_ptr_;             // 单机全局下,CSR行偏移  std::vector<RowBatch::Entry> row_data_;   // 单机全局下,CSR稀疏存储内存块,以bacth vector方式存储  RowBatch batch_;  // 调用Next会将数据指针绑定到该变量下  bool at_first_;   // 开始标记,BeforeFirst会设置true, 首次调用Next设置成false。再调用Next会退出,也就是说只有一次Next调用执行,所以说Simple  bool Next();                  // 直接将全量数据的指针赋值给RowBatch batch_;  const RowBatch &Value();      // 返回batch_变量引用}

  了解整个数据加载过程,以下3个类非常重要:

struct SparseBatch {  // 稀疏项描述  struct Entry {    bst_uint index;             // 稀疏项索引    bst_float fvalue;           // 稀疏项数值    Entry() {}    Entry(bst_uint index, bst_float fvalue) : index(index), fvalue(fvalue) {}    // 稀疏项比较,特征值排序会用到    inline static bool CmpValue(const Entry& a, const Entry& b) {      return a.fvalue < b.fvalue;    }  };  // batch中稀疏向量,如果把一个instance看做一行,特征项为稀疏项,一行是一个稀疏向量  struct Inst {    const Entry *data;          // 稀疏向量数据块指针    bst_uint length;            // 稀疏项个数    Inst() : data(0), length(0) {}    Inst(const Entry *data, bst_uint length) : data(data), length(length) {}    inline const Entry& operator[](size_t i) const {    // 访问第i个稀疏项      return data[i];    }  };  size_t size;                  // batch中的稀疏向量数};// 基于CSR数据格式,稀疏存储。类似地,ColBatch是基于CSC数据格式。struct RowBatch : public SparseBatch {  size_t base_rowid;            // 每个批次的rowid偏移  const size_t *ind_ptr;        // 偏移数组,大小为size+1,ind_ptr[i]表示第i行开始数据的偏移位置  const Entry *data_ptr;        // 批次数据块的指针  // batch中第i行的数据  inline Inst operator[](size_t i) const {    return Inst(data_ptr + ind_ptr[i], static_cast<bst_uint>(ind_ptr[i + 1] - ind_ptr[i]));  }};

  2. source->CopyFrom(parser);调用解析器从文件中不断批次解析出数据,存入变量row_data_中,核心逻辑如下:

while (parser->Next()) {  // 不断获取batch数据,并判断是否结束  const dmlc::RowBlock<uint32_t>& batch = parser->Value();  // 得到批次数据  if (batch.label != nullptr) {     // 解析得到label值    info.labels.insert(info.labels.end(), batch.label, batch.label + batch.size);  }  if (batch.weight != nullptr) {    // 解析得到weight值    info.weights.insert(info.weights.end(), batch.weight, batch.weight + batch.size);  }  for (size_t i = batch.offset[0]; i < batch.offset[batch.size]; ++i) {    uint32_t index = batch.index[i];    bst_float fvalue = batch.value == nullptr ? 1.0f : batch.value[i];    row_data_.push_back(SparseBatch::Entry(index, fvalue)); // 将值存入row_data_变量    this->info.num_col = std::max(this->info.num_col, static_cast<uint64_t>(index + 1));  }}

  解析器对象parser基于ThreadedParser,包装了LibSVMParser,两者都基于ParserImpl, 成员如下:

template <typename IndexType>class ThreadedParser : public ParserImpl<IndexType> {  Parser<IndexType> *base_;         // 绑定LibSVMParser,真正数据解析器  ThreadedIter<std::vector<RowBlockContainer<IndexType> > > iter_; // 后端线程迭代器  std::vector<RowBlockContainer<IndexType> > *tmp_;     //当前数据chunk};

  ThreadedIter基于DataIter,需要实现BeforeFirst()Next()Value()方法, 使用线程构建消费者与生产者模式,生产者预取数据。生产者内部调用ParseNext()解析数据,核心调用FillData()
    1)InputSplit对象调用NextChunk,每次最多读取16MB数据到Chunk,再转化到Blob对象中,内部使用ReadChunk将文件数据读入到内存Chunk中,使用overflow_缓存分片位置到记录结束位置以便调整下一次文件指针,因此,batch最后读出的记录是完整的。

// 内存blob,只保留数据指针和大小,真正的数据hold在Chunkstruct Blob {  void *dptr;     // 内存blob指针  size_t size;    // 内存blob大小};// 内存Chunk,保留着真正的数据struct Chunk {  char *begin;  char *end;  std::vector<size_t> data; // 每次数据最多读入16MB  explicit Chunk(size_t buffer_size)    : begin(NULL), end(NULL),      data(buffer_size + 1) {}  bool Load(InputSplitBase *split, size_t buffer_size);};

    2)使用OMP多线程并行调用ParseBlock解析数据到具备vector变量data,基于数据均分定位线程处理的数据起始位置,为了保证边际在记录中的情况,需要BackFindEndLine做指针相应偏移。ParseBlock根据行偏移开始与行结束位置,将数据解析到线程对应RowBlockContainer<IndexType>对象(*data)[tid]中,ParsePair底层函数会解析出weight、label、特征value,根据libsvm的冒号:进行token分割。 RowBlockContainer模板类定义如下,数据存储格式为CSR格式。

template<typename IndexType>struct RowBlockContainer {  std::vector<size_t> offset;       // offset[i]为第i行非0特征数据的偏移  std::vector<real_t> label;        // offset[i]为第i行的label值  std::vector<real_t> weight;       // offset[i]为第i行的weight值  std::vector<IndexType> field;     // 特征field值  std::vector<IndexType> index;     // 特征id值  std::vector<real_t> value;        // 特征数值,需结合offset  IndexType max_field;              // 最大field值  IndexType max_index;              // 最大特征id值}// RowBlock成员指针指向对应RowBlockContainer成员vector内存起始位置template<typename IndexType>struct RowBlock {  size_t size;                      // 批次大小  const size_t *offset;             // offset数据块指针  const real_t *label;              // label数据块指针  const real_t *weight;             // weight数据块指针   const IndexType *field;           // field数据块指针    const IndexType *index;           // index数据块指针  const real_t *value;              // value数据块指针};

  ThreadedParser对象调用Next()生成RowBlock<IndexType>数据,即通过获取RowBlockContainer<IndexType>数据的内存初始位置。如果消费完预取数据,则iter_.Next(&tmp_)继续解析数据到tmp_中。每次parser->Value()生成batch数据,更新MetaInfo信息labels、weights、num_nonzeronum_row。此外,稀疏特征数据项SparseBatch::Entry插入row_data_,更新对应row_ptr_偏移。可以说row_data_、row_ptr_是最终的全局数据。

  3. 使用解析器中的数据源对象source构造返回SimpleDMatrix对象,基类DMatrixSparsePageDMatrix对象需要结合cache_prefix文件,目前配置并不支持,以后再讲。至此,SimpleDMatrix提供RowIterator按照行遍历实例数据的能力,即CSR数据存储,SimpleDMatrix调用Next()直接绑定全量数据,这也是为什么叫做Simple的原因了。后续会将CSR转化为CSC格式,提供ColBatchIter遍历器,主要利用ParallelGroupBuilder对象多线程并行实现,具体细节详见此处LazyInitDMatrix过程。