源码解析Mxnet Dependency Engine
KeeSkemp
8年前
<h2>Var</h2> <p>var可以看做是一个tag,用来标示每一个对象的,这样Op对对象的依赖可以简化成对var的 依赖,这样就可以构建出一个不依赖于具体的对象的通用的依赖引擎。Var是依赖引擎的关键。</p> <h3>类图</h3> <p><img src="https://simg.open-open.com/show/90ca065965bf7225291ddc8b3eedf6ee.png"></p> <p>声明:下文说到执行时,意思是Op的当前var的依赖已经就绪,因为一个op可以依赖多个 var,如果其他的Var没有就绪,那么这时op可能并没有实际运行</p> <p>Var只是一个基类,用来统一类型系统的,主要的工作在 ThreadedVar 中,每一个对象都 会有一个由 VersionedVarBlock 所组成的链表,这个链表就是一个FIFO队列。 head_ 指向的是队列的尾部, 实际是一个哨兵(空对象), head_ 这个命名有误导性, pending_write_ 指向的是最"老"的写依赖,如果没有写依赖,那么就指向 nullptr , 根据依赖引擎的特点,它实际上指向的是队列的头部, ThreadedVar 的那四个方法就是 来操作这个队列的。</p> <ol> <li>num_pending_reads_: 代表当前正在执行(还没有执行完)的读依赖的个数</li> <li>pending_write_: 代表队列中最“老”的写依赖, 它一直指向队列的头部。</li> <li>head_: 队列的尾部。</li> </ol> <p>需要注意的是,正在执行的读依赖是不在队列中的,但是正在执行的写依赖是在队列中的。</p> <h3>理解Var的队列</h3> <p>var的队列是依赖引擎的核心,下面我们来分析下各种情况下,如何修改队列的状态。</p> <ol> <li>添加读依赖: 如果前面没有写依赖,那么直接运行, 否则就插入队列的尾部(head_那一端)</li> <li>添加写依赖: 直接将依赖插入队列的尾部,并检查是不是写就绪(既没有读依赖也没有 写依赖在运行),如果是写就绪,那么就运行该依赖。</li> <li>读依赖完成</li> <li>写依赖完成</li> </ol> <p><img src="https://simg.open-open.com/show/9f937b7ae9ba11ddf612df92a34b03d3.png"></p> <p>上图中w1写依赖正在执行。</p> <p><img src="https://simg.open-open.com/show/717598cecc13ae14d295088301eba0e3.png"> 写依赖w1完成将自己移出队列,并执行写依赖w2</p> <p><img src="https://simg.open-open.com/show/a4cd3724dfd22a52c91d8240a7ee5486.png"></p> <p>写依赖w2完成后将自己移出队列,接着并行的执行读依赖r1,r2,记住正在执行的读依赖是被移出队列的, 它们的数目使用 num_pending_reads_ 跟踪的</p> <p><img src="https://simg.open-open.com/show/1f64638d0219082f781e041f61ed65d0.png"></p> <p>每一个读依赖完成都会将 num_pending_reads_ 减一,如果减为了0,那么就意味着所有 的读依赖都完成了,当r1,r2都完成后,接着执行w3写依赖。</p> <h3>添加读依赖</h3> <p>代码主要在 src/engine/Threaded_engine.cc 的 AppendReadDependency 中。</p> <pre> <strong>inline void ThreadedVar::AppendReadDependency(OprBlock* opr_block) { std::lock_guard<std::mutex> lock{m_}; <strong>if (pending_write_ == nullptr) { // invariant: is_ready_to_read() CHECK_GE(num_pending_reads_, 0); // STATE CHANGE ++num_pending_reads_; // decrease wait counter opr_block->decr_wait(); } <strong>else { <strong>auto&& new_var_block = VersionedVarBlock::New(); assert(head_->next == nullptr); assert(head_->trigger == nullptr); assert(head_->write == false); // append things to next. head_->next = new_var_block; head_->trigger = opr_block; head_ = new_var_block; } }</strong></strong></strong></strong></pre> <p>代码的基本思路是这样的:检查队列中有没有写依赖,这分两种情况:</p> <ol> <li>如果没有写依赖,那么意味着,目前该Var没有依赖在执行,或者说只有读依赖在执行, 所以这个新的读依赖可以直接执行,那么它没有必要添加到队列中,只需要更新 num_pending_reads_ 就好,当然因为该op可能还依赖别的var,所以你只能调用 decr_wait ,只有当wait减为0的时候,才能开始运行。这部分代码在engine的push中。</li> <li>如果有写依赖,那么读依赖必须在写依赖的后面执行,所以需要把读依赖添加到队列的 尾部。记住 head_ 永远指向一个空的哨兵对象。</li> </ol> <h3>添加写依赖</h3> <p>代码主要在 src/engine/Threaded_engine.cc 的 AppendWriteDependency 中。</p> <pre> <strong>inline void ThreadedVar::AppendWriteDependency(OprBlock* opr_block) { <strong>auto&& new_var_block = VersionedVarBlock::New(); std::lock_guard<std::mutex> lock{m_}; // invariant. assert(head_->next == nullptr); assert(head_->trigger == nullptr); assert(head_->write == false); // attach to head. head_->next = new_var_block; head_->trigger = opr_block; head_->write = true; // check if it is ready to write <strong>if (pending_write_ == nullptr) { // invariant: is_ready_to_read() pending_write_ = head_; CHECK_GE(num_pending_reads_, 0); <strong>if (num_pending_reads_ == 0) { // STATE CHANGE opr_block->decr_wait(); num_pending_reads_ = kWriteTriggered; } } <strong>else { CHECK_NE(num_pending_reads_, 0); } head_ = new_var_block; }</strong></strong></strong></strong></strong></pre> <p>代码的基本思路是这样的: 将该Op放入队列的尾部,接着检查该Op的依赖有没有就绪,这 要检查Var有没有写依赖(pending_read_==nullptr)和读依赖(num_pending_read_==0)的Op 正在执行,只有二者都没有时,才能开始运行,当然你依然要检查该Op对其他的Var的依赖 有没有就绪。需要注意的一点是,即便Op的Var写依赖就绪,该Op也不会从队列中移除,只 有该Op执行完成后才会被移除,这在CompleteWriteDependency中实现。</p> <h3>读依赖完成</h3> <p>代码主要在 src/engine/Threaded_engine.cc 的 CompleteReadDependency 中。</p> <pre> <strong>template <<strong>typename Dispatcher> <strong>inline void ThreadedVar::CompleteReadDependency(Dispatcher dispatcher) { OprBlock *trigger = nullptr; { // this is lock scope std::lock_guard<std::mutex> lock{m_}; CHECK_GT(num_pending_reads_, 0); <strong>if (--num_pending_reads_ == 0) { <strong>if (pending_write_ != nullptr) { // STATE CHANGE trigger = pending_write_->trigger; num_pending_reads_ = kWriteTriggered; } } } <strong>if (trigger != nullptr && trigger->decr_wait() == 0) { dispatcher(trigger); } }</strong></strong></strong></strong></strong></strong></pre> <p>该部分代码会在一个op运算完成后调用,代码逻辑是比较简单的,先更新 num_pending_read_ , 更新后如果该值为0,那么就意味着,所有的读依赖都已经执行完成, 这样就检查队列,若是存在写依赖,那么该写依赖就就绪了,那么Op就可以执行了(前提是 依赖的其他var也都就绪了, wait为0)。上面的dispatcher实际就是用来将Op丢入执行引擎 的,它一般是PushToExecute,这个后文会看到。</p> <h3>写依赖完成</h3> <p>代码主要在 src/engine/Threaded_engine.cc 的 CompleteWriteDependency 中。</p> <pre> <strong>template <<strong>typename Dispatcher> <strong>inline bool ThreadedVar::CompleteWriteDependency(Dispatcher dispatcher) { // this is lock scope VersionedVarBlock *old_pending_write, *end_of_read_chain; OprBlock* trigger_write = nullptr; { std::lock_guard<std::mutex> lock{m_}; // invariants assert(head_->next == nullptr); assert(pending_write_ != nullptr); CHECK_EQ(num_pending_reads_, kWriteTriggered); // really delete <strong>if (to_delete_) { VersionedVarBlock *head = pending_write_->next; VersionedVarBlock::Delete(pending_write_); assert(head_ == head); VersionedVarBlock::Delete(head); <strong>return true; } // detach pending write old_pending_write = pending_write_; // search for chains to trigger end_of_read_chain = old_pending_write->next; // reset to 0 pending reads num_pending_reads_ = 0; <strong>while (end_of_read_chain != head_ && end_of_read_chain->write == false) { ++num_pending_reads_; end_of_read_chain = end_of_read_chain->next; } <strong>if (end_of_read_chain == head_) { pending_write_ = nullptr; } <strong>else { // check if there is pending reads, if not trigger write assert(end_of_read_chain->write == true); pending_write_ = end_of_read_chain; <strong>if (num_pending_reads_ == 0) { // mark write as already actived in this var num_pending_reads_ = kWriteTriggered; trigger_write = end_of_read_chain->trigger; } } } // This is outside of lock scope // Be very carful, pending_write_ and num_pending_reads_ // can change now, do not reply ont the two variables. // The linked list \in [old_pending_write, end_of_read_chain) // is already detached from this Var. // So it is safe to modify these VersionedVarBlock *cur_head = old_pending_write->next; VersionedVarBlock::Delete(old_pending_write); // dispatch all the events <strong>while (cur_head != end_of_read_chain) { <strong>if (cur_head->trigger->decr_wait() == 0) { dispatcher(cur_head->trigger); } <strong>auto prev = cur_head; cur_head = cur_head->next; assert(cur_head != nullptr); VersionedVarBlock::Delete(prev); } <strong>if (trigger_write != nullptr && trigger_write->decr_wait() == 0) { dispatcher(trigger_write); } <strong>return false; }</strong></strong></strong></strong></strong></strong></strong></strong></strong></strong></strong></strong></strong></strong></pre> <p>和读依赖完成类似,只是写依赖的后面可能跟着多个读依赖,所以需要遍历链表直到发现下 一个写依赖,找到之后如果是读依赖,那么直接并行的运行,如果是写依赖直接运行就好。</p> <p> </p> <p>来自: <a href="/misc/goto?guid=4959674162605750013" rel="nofollow">http://yuyang0.github.io/notes/mxnet-engine.html</a></p> <p> </p>