博客
关于我
强烈建议你试试无所不能的chatGPT,快点击我
MXNet 源码解读系列之一 C++端如何解析NDArray参数文件
阅读量:7045 次
发布时间:2019-06-28

本文共 12612 字,大约阅读时间需要 42 分钟。

  hot3.png

本文相关代码:

      要想弄清楚MXNet 是如何解析参数文件,并从中提取预训练好的权值,首先第一步要看

MXNet Python端是如何是调用C接口来完成读取NDArray参数文件的。

      这部分代码见源码  第149行:

def load(fname):    """Loads an array from file.    See more details in ``save``.    Parameters    ----------    fname : str        The filename.    Returns    -------    list of NDArray, RowSparseNDArray or CSRNDArray, or \    dict of str to NDArray, RowSparseNDArray or CSRNDArray        Loaded data.    """    if not isinstance(fname, string_types):        raise TypeError('fname required to be a string')    out_size = mx_uint()    out_name_size = mx_uint()    handles = ctypes.POINTER(NDArrayHandle)()    names = ctypes.POINTER(ctypes.c_char_p)()    check_call(_LIB.MXNDArrayLoad(c_str(fname),                                                                           ctypes.byref(out_size),                                  ctypes.byref(handles),                                  ctypes.byref(out_name_size),                                  ctypes.byref(names)))    if out_name_size.value == 0:        return [_ndarray_cls(NDArrayHandle(handles[i])) for i in range(out_size.value)]    else:        assert out_name_size.value == out_size.value        return dict(            (py_str(names[i]), _ndarray_cls(NDArrayHandle(handles[i])))            for i in range(out_size.value))

       这个 load 函数接收参数路径作为输入,然后根据参数文件中有没有包含参数的名字选择返回

NDArray参数数组或者字典。然后可以看到是调用了 MXNDArrayLoad 这个C接口函数,这个函数的

代码见  第308行:

int MXNDArrayLoad(const char* fname,                  mx_uint *out_size,                  NDArrayHandle** out_arr,                  mx_uint *out_name_size,                  const char*** out_names) {  MXAPIThreadLocalEntry *ret = MXAPIThreadLocalStore::Get();  ret->ret_vec_str.clear();  API_BEGIN();  std::vector
data; std::vector
&names = ret->ret_vec_str; { std::unique_ptr
fi(dmlc::Stream::Create(fname, "r")); mxnet::NDArray::Load(fi.get(), &data, &names); } ret->ret_handles.resize(data.size()); for (size_t i = 0; i < data.size(); ++i) { NDArray *ptr = new NDArray(); *ptr = data[i]; ret->ret_handles[i] = ptr; } ret->ret_vec_charp.resize(names.size()); for (size_t i = 0; i < names.size(); ++i) { ret->ret_vec_charp[i] = names[i].c_str(); } *out_size = static_cast
(data.size()); *out_arr = dmlc::BeginPtr(ret->ret_handles); *out_name_size = static_cast
(names.size()); *out_names = dmlc::BeginPtr(ret->ret_vec_charp); API_END();}

     然后可以看到最核心的代码就是第319行调用了NDArray类的静态Load函数获得参数的名字和

内容,Load函数具体实现见: 第 1812行:

void NDArray::Load(dmlc::Stream* fi,                   std::vector
* data, std::vector
* keys) { uint64_t header, reserved; CHECK(fi->Read(&header)) << "Invalid NDArray file format"; CHECK(fi->Read(&reserved)) << "Invalid NDArray file format"; CHECK(header == kMXAPINDArrayListMagic) << "Invalid NDArray file format"; CHECK(fi->Read(data)) << "Invalid NDArray file format"; CHECK(fi->Read(keys)) << "Invalid NDArray file format"; CHECK(keys->size() == 0 || keys->size() == data->size()) << "Invalid NDArray file format";}

        从这里读取内容的过程可以大概看出NDArray参数文件存储的内容的顺序是什么了,首先是会

存两个uint64_t类型的数字,然后就是NDArray数组,接着是每个NDArray对应的名字的数组。

        好了接下来就是解读源码中是如何从Stream中解析出内容的,首先我们来看下Stream类的

Read函数,具体见  第435行:

template
inline bool Stream::Read(T *out_data) { return serializer::Handler
::Read(this, out_data);}

        这里可以看到,Read 函数内部又调用了 Handler这个类的Read静态函数,这个静态函数对应的

代码见  第262行:

inline static bool Read(Stream *strm, T *data) {    return IfThenElse
::value, PODHandler
, IfThenElse
::value, SaveLoadClassHandler
, UndefinedSerializerFor
, T>, T> ::Read(strm, data); }};

        这里代码我第一次看的时候有点蒙,后来仔细研究了下也看懂了。首先我们要看 IfThenElse

是什么东西,这里还是看到 io.h 的第 38 到 66行:

//! \cond Doxygen_Suppress/*! * \brief Serializer that redirect calls by condition * \tparam cond the condition * \tparam Then the serializer used for then condition * \tparam Else the serializer used for else condition * \tparam Return the type of data the serializer handles */template
struct IfThenElse;template
struct IfThenElse
{ inline static void Write(Stream *strm, const T &data) { Then::Write(strm, data); } inline static bool Read(Stream *strm, T *data) { return Then::Read(strm, data); }};template
struct IfThenElse
{ inline static void Write(Stream *strm, const T &data) { Else::Write(strm, data); } inline static bool Read(Stream *strm, T *data) { return Else::Read(strm, data); }};

    这里可以看到 IfThenElse 就是一个结构体,有四个模板参数,意思很明显了,如果第一个参数

为true,则会调用Then这个类的Read静态函数,如果第一个参数为false,则会调用Else这个类的

Read静态函数。看完 IfThenElse 的定义之后,我们看回 262 行的Read函数就很清楚了,

inline static bool Read(Stream *strm, T *data) {    return IfThenElse
::value, PODHandler
, IfThenElse
::value, SaveLoadClassHandler
, UndefinedSerializerFor
, T>, T> ::Read(strm, data); }};

    意思就是,如果 dmlc::is_pod<T>::value 这个值为 true,那么就会调用 PODHandler 的Read

函数,否则就会走到下一个条件判断,下一个条件判断是当 dmlc::has_saveload<T>::value 这个值

为true的话就调用 SaveLoadClassHandler 的 Read 静态函数,否则就走到 UndefinedSerializerFor。

好了,那么现在就是要看具体走了哪个分支,首先我们要知道T在运行时时什么类型,看回上面的

NDArray 的 Load 函数,知道了首先读取得两个数字的类型是 uint64_t,接着跳转到

源码 ,看第126和第152行:

/*! \brief macro to quickly declare traits information */#define DMLC_DECLARE_TRAITS(Trait, Type, Value)       \  template<>                                          \  struct Trait
{ \ static const bool value = Value; \ }DMLC_DECLARE_TRAITS(is_pod, uint64_t, true);

        很明显可以看到,dmlc::is_pod<uint64_t>::value 的值为 true,因此会调用 PODHandler 的

Read 函数,代码:

/*! \brief Serializer for POD(plain-old-data) data */template
struct PODHandler { inline static void Write(Stream *strm, const T &data) { strm->Write(&data, sizeof(T)); } inline static bool Read(Stream *strm, T *dptr) { return strm->Read((void*)dptr, sizeof(T)) == sizeof(T); // NOLINT(*) }};

        PODHandler 的Read函数就是调用 Stream 的Read,这里如果读者想再详细了解 Stream 类

Read 函数的工作原理可以自己再去细看,不过对于本文来说,到这里知道了会根据T的字节数读取

内容到dptr里面就够了。

        Ok,现在已经读取完两个数字 header, reserved,然后就是读 NDArray Vector 了,然后这里

还是跳转到,调用 Handler<T>::Read 函数,不过这里和读数字不一样的地方在于,这里传入的模板

参数是vector<NDArray>,所以调用的是下面这个Handler定义的Read函数:

//! \cond Doxygen_Suppresstemplate
struct Handler
> { inline static void Write(Stream *strm, const std::vector
&data) { IfThenElse
::value, PODVectorHandler
, ComposeVectorHandler
, std::vector
> ::Write(strm, data); } inline static bool Read(Stream *strm, std::vector
*data) { return IfThenElse
::value, PODVectorHandler
, ComposeVectorHandler
, std::vector
> ::Read(strm, data); }};

    然后这里的判断分支是会调用 ComposeVectorHandler 的 Read 函数:

/*! * \brief Serializer handler for std::vector
where T can be composed type * \tparam T element type */template
struct ComposeVectorHandler { inline static void Write(Stream *strm, const std::vector
&vec) { uint64_t sz = static_cast
(vec.size()); strm->Write(&sz, sizeof(sz)); for (size_t i = 0; i < vec.size(); ++i) { Handler
::Write(strm, vec[i]); } } inline static bool Read(Stream *strm, std::vector
*out_vec) { uint64_t sz; if (strm->Read(&sz, sizeof(sz)) != sizeof(sz)) return false; size_t size = static_cast
(sz); out_vec->resize(size); for (size_t i = 0; i < size; ++i) { if (!Handler
::Read(strm, &(*out_vec)[i])) return false; } return true; }};

        首先先读出 vector 数组的大小,然后分别读取每个 NDArray,这里在读每个 NDArray 的时候

又会调用 Handler<T>::Read 函数,这次 IfThenElse 分支判断那里会走 SaveLoadClassHandler

这个分支:

// serializer for class that have save/load functiontemplate
struct SaveLoadClassHandler { inline static void Write(Stream *strm, const T &data) { data.Save(strm); } inline static bool Read(Stream *strm, T *data) { return data->Load(strm); }};

    最后看到其实就是调用了 NDArray 类本身的 Load 函数,见源码 :

bool NDArray::Load(dmlc::Stream *strm) {  uint32_t magic;  if (strm->Read(&magic, sizeof(uint32_t)) != sizeof(uint32_t)) return false;  if (magic != NDARRAY_V2_MAGIC) {    return LegacyLoad(strm, magic);  }  // load storage type  int32_t stype;  if (strm->Read(&stype, sizeof(stype)) != sizeof(stype)) return false;  const int32_t nad = num_aux_data(static_cast
(stype)); // load storage shape TShape sshape; if (nad > 0) { if (!sshape.Load(strm)) return false; } // load shape TShape shape; if (!shape.Load(strm)) return false; if (shape.ndim() == 0) { *this = NDArray(); return true; } // load context Context ctx; if (!ctx.Load(strm)) return false; // load type flag int32_t type_flag; if (strm->Read(&type_flag, sizeof(type_flag)) != sizeof(type_flag)) return false; // load aux_types and aux_shapes std::vector
aux_types; std::vector
aux_shapes; if (nad > 0) { aux_types.resize(nad); aux_shapes.resize(nad); for (int i = 0; i < nad; ++i) { // load aux_type(i) if (strm->Read(&aux_types[i], sizeof(aux_types[i])) != sizeof(aux_types[i])) return false; // load aux_shapes(i) if (!aux_shapes[i].Load(strm)) return false; } } // load data into CPU NDArray temp; if (0 == nad) { temp = NDArray(shape, Context::CPU(), false, type_flag); } else { temp = NDArray(static_cast
(stype), shape, Context::CPU(), false, type_flag, aux_types, aux_shapes, sshape); } // load data TBlob load_data = temp.data(); size_t type_size = mshadow::mshadow_sizeof(type_flag); size_t nread = type_size * load_data.Size(); if (strm->Read(load_data.dptr_, nread) != nread) return false; // load aux_data if (nad > 0) { for (int i = 0; i < nad; ++i) { load_data = temp.aux_data(i); type_size = mshadow::mshadow_sizeof(load_data.type_flag_); nread = type_size * load_data.Size(); if (strm->Read(load_data.dptr_, nread) != nread) return false; } } if (ctx.dev_mask() == cpu::kDevMask) { *this = std::move(temp); return true; } else {#if MXNET_USE_CUDA *this = temp.Copy(ctx); return true;#else *this = std::move(temp); return true;#endif }}

    这里首先,读出一个 magic number ,如果用 V1.0 之后的MXNet版本,magic number

都是会等于  NDARRAY_V2_MAGIC,具体定义见下面:

/* magic number for ndarray version 1, with int64_t TShape */static const uint32_t NDARRAY_V1_MAGIC = 0xF993fac8;/* magic number for ndarray version 2, with storage type */static const uint32_t NDARRAY_V2_MAGIC = 0xF993fac9;

        所以不会进入 LegacyLoad 函数,接着就是读 storage type,NDArray的类型,除了常用的

普通类型,现在也已经支持了稀疏类型:

enum NDArrayStorageType {  kUndefinedStorage = -1,  // undefined storage  kDefaultStorage,         // dense  kRowSparseStorage,       // row sparse  kCSRStorage,             // csr};

    一般来说,storage type 都是 kDefaultStorage 类型,我现在写的解析小工具里面也只考虑

了解析普通类型的NDArray,之后再改进吧。然后看到 num_aux_data函数,这个函数如果传入

普通类型则返回0,所以 nad 的值为 0。

size_t num_aux_data(NDArrayStorageType stype) {  size_t num = 0;  switch (stype) {    case kDefaultStorage: num = 0; break;    case kCSRStorage: num = 2; break;    case kRowSparseStorage: num = 1; break;     default: LOG(FATAL) << "Unknown storage type" << stype; break;  }  return num;}

    nad 值为0 的话整个代码就简洁很多了,简化之后如下:

bool NDArray::Load(dmlc::Stream *strm) {  uint32_t magic;  if (strm->Read(&magic, sizeof(uint32_t)) != sizeof(uint32_t)) return false;  if (magic != NDARRAY_V2_MAGIC) {    return LegacyLoad(strm, magic);  }  // load storage type  int32_t stype;  if (strm->Read(&stype, sizeof(stype)) != sizeof(stype)) return false;  // load shape  TShape shape;  if (!shape.Load(strm)) return false;  if (shape.ndim() == 0) {    *this = NDArray(); return true;  }  // load context  Context ctx;  if (!ctx.Load(strm)) return false;  // load type flag  int32_t type_flag;  if (strm->Read(&type_flag, sizeof(type_flag)) != sizeof(type_flag)) return false;  // load data into CPU  NDArray temp;  temp = NDArray(shape, Context::CPU(), false, type_flag);    // load data  TBlob load_data = temp.data();  size_t type_size = mshadow::mshadow_sizeof(type_flag);  size_t nread = type_size * load_data.Size();  if (strm->Read(load_data.dptr_, nread) != nread) return false;}

    到这里为,大概怎么读取NDArray,相信应该挺清晰的了。

转载于:https://my.oschina.net/Ldpe2G/blog/1831462

你可能感兴趣的文章
机器学习之sklearn——EM
查看>>
tengine整合tomcat加上memcached实现高并发、负载均衡、可扩展架构
查看>>
CloudStack追求简单易用
查看>>
declare 声明Shell变量
查看>>
敏捷开发般若敏捷系列之八:敏捷的未来会怎样?
查看>>
Java 编程
查看>>
我的友情链接
查看>>
mariadb常用备份与还原工具介绍
查看>>
F5服务器负载均衡测试方案
查看>>
深入内存
查看>>
python virtualenv 需要使用系统的第三方包。
查看>>
得到spring的上下文
查看>>
向电信联通开刀背后的真相
查看>>
VII Python(7)爬虫
查看>>
android Installation error: INSTALL_FAILED_CONTAINER_ERROR错误
查看>>
数据库安全管理
查看>>
java位运算
查看>>
我的友情链接
查看>>
我的友情链接
查看>>
linux基本命令
查看>>