global _pg_group_ranks
global _backend
global _default_pg_init_method
if store is not None:
    assert world_size > 0, 'world_size must be positive if using store'
    assert rank >= 0, 'rank must be non-negative if using store'
elif init_method is None:
    init_method = "env://"
backend = Backend(backend)
if backend == Backend.MPI:
    default_pg = _new_process_group_helper(
        -1, -1, [], Backend.MPI, None, group_name=group_name, timeout=timeout)
    _update_default_pg(default_pg)
else:
    # backward compatible API
    if store is None:
        # If no store was given, we still have to build one from init_method.
        rendezvous_iterator = rendezvous(
            init_method, rank, world_size, timeout=timeout
        )
        store, rank, world_size = next(rendezvous_iterator)
        store.set_timeout(timeout)
def register_rendezvous_handler(scheme, handler):
    """Registers a new rendezvous handler.

    Args:
        scheme (str): URL scheme to identify your rendezvous handler.
        handler (function): Handler that is invoked when the
            `rendezvous()` function is called with a URL that uses
            the corresponding scheme. It must be a generator function
            that yields the triplet.
    """
    global _rendezvous_handlers
    if scheme in _rendezvous_handlers:
        raise RuntimeError(
            "Rendezvous handler for {}:// already registered".format(scheme)
        )
    _rendezvous_handlers[scheme] = handler
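To make the handler contract concrete, here is a small sketch of registering a custom scheme. The scheme name myscheme and the FileStore path are invented for this illustration; the only requirement from the docstring above is that the handler is a generator yielding the (store, rank, world_size) triplet.

import torch.distributed as dist
from torch.distributed.rendezvous import register_rendezvous_handler
from urllib.parse import urlparse, parse_qs

def my_rendezvous_handler(url, timeout=None, **kwargs):
    # rank and world_size are carried in the URL query string by rendezvous()
    parsed = urlparse(url)
    query = {k: v[0] for k, v in parse_qs(parsed.query).items()}
    rank = int(query["rank"])
    world_size = int(query["world_size"])
    # Any Store works here; a FileStore keeps the sketch self-contained
    store = dist.FileStore(parsed.path, world_size)
    yield (store, rank, world_size)

register_rendezvous_handler("myscheme", my_rendezvous_handler)
# Hypothetical usage:
# dist.init_process_group("gloo", init_method="myscheme:///tmp/rdzv?rank=0&world_size=1")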
if use_torchelastic_store == str(True):
    worker_process_prefix = "/worker"
    # When TORCHELASTIC_USE_AGENT_STORE is set up, the worker process is assumed
    # to be invoked by the torchelastic agent. Torchelastic agent creates a tcp daemon thread
    # on the GROUP_RANK=0, as a result all user worker processes should create store with: daemon=False
    tcp_store = TCPStore(master_addr, master_port, world_size, False, timeout)
    yield (PrefixStore(worker_process_prefix, tcp_store), rank, world_size)
else:
    # Start the TCP store daemon on the rank 0
    start_daemon = rank == 0
    store = TCPStore(master_addr, master_port, world_size, start_daemon, timeout)
    yield (store, rank, world_size)
# If this configuration is invalidated, there is nothing we can do about it
raise RuntimeError("Unable to perform rerendezvous using env:// method")
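For the common env:// path, a minimal (hypothetical) launch just sets the environment variables the handler reads and then calls init_process_group; the address, port, rank and world size below are placeholders for a single-process run.

import os
import torch.distributed as dist

# env:// reads MASTER_ADDR, MASTER_PORT, RANK and WORLD_SIZE from the environment
os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29500")
os.environ.setdefault("RANK", "0")
os.environ.setdefault("WORLD_SIZE", "1")

dist.init_process_group(backend="gloo", init_method="env://")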
# Cached process groups
# For NCCL and GLOO pg, it is a map from ProcessGroup to (Backend, Store)
# For MPI pg, it is a map from ProcessGroup to (Backend, None)
_pg_map: Dict[ProcessGroup, Tuple[str, Optional[Store]]] = {}
# Process group's names, map from ProcessGroup to str
_pg_names: Dict[ProcessGroup, str] = {}
# Process group's global rank to local rank mapping
_pg_group_ranks: Dict[ProcessGroup, Dict[int, int]] = {}
Once _new_process_group_helper has obtained the store argument, it builds a prefix_store from it, and this prefix_store is then used to construct the ProcessGroupGloo. The code of _new_process_group_helper is as follows:
def _new_process_group_helper(world_size,
                              rank,
                              group_ranks,
                              backend,
                              store,
                              pg_options=None,
                              group_name=None,
                              timeout=default_pg_timeout):
    """
    Create a new distributed process group.

    This function must be called by ALL processes in the global group, even if
    the calling process is not part of the newly created group. In that case,
    this function returns GroupMember.NON_GROUP_MEMBER.

    This function is called with ``group_ranks == []`` for the default group.
    """
    global _pg_map
    global _group_count
    global _pg_names
    # If this is a subgroup (which means group_ranks is specified),
    # we check if the current process is a member of the new group.
    if not is_default_group:
        global_rank = _get_default_group().rank()
        if global_rank not in group_ranks:
            return GroupMember.NON_GROUP_MEMBER
    # Use the group name as prefix in the default store, such that
    # a single store can be reused by multiple groups.
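As a rough sketch of this prefix idea (the port number and prefixes here are arbitrary), two PrefixStore wrappers can share one underlying TCPStore without their keys colliding:

import torch.distributed as dist
from datetime import timedelta

# One real store, owned by this single process
base = dist.TCPStore("127.0.0.1", 1234, 1, True, timedelta(seconds=30))

# Each "group" gets its own namespace on top of the shared store
group0 = dist.PrefixStore("group_0/", base)
group1 = dist.PrefixStore("group_1/", base)

group0.set("key", "a")
group1.set("key", "b")
print(group0.get("key"), group1.get("key"))  # b'a' b'b' -- no collision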
contexts_.reserve(options->devices.size());
for (size_t i = 0; i < options->devices.size(); i++) {
  auto context = std::make_shared<::gloo::rendezvous::Context>(rank_, size_);
  // Build yet another PrefixStore on top of the store that was passed in
  auto store = ::gloo::rendezvous::PrefixStore(std::to_string(i), *store_);
  context->setTimeout(options->timeout);
  // Use the PrefixStore to set up the network (full-mesh connections)
  context->connectFullMesh(store, options->devices[i]);
  contexts_.push_back(std::move(context));
}
// Every worker thread stores the AsyncWork object it's currently
// working on in the workInProgress_ vector. It must have size equal
// to the number of workers such that they can simply index into it
// using the worker index they are started with.
workInProgress_.resize(options->threads);
threads_.resize(options->threads);
for (size_t i = 0; i < threads_.size(); i++) {
  threads_[i] = std::thread(&ProcessGroupGloo::runLoop, this, i);
}
}
The code below also makes use of store_, for example to wait on keys and to set/get values.
void ProcessGroupGloo::setSequenceNumberForGroup() {
  if (rank_ == 0) {
    // Create and broadcast sequence number
    auto seq = 1 + rand();
    sequenceNum_ = c10d::SequenceNum(seq);
    std::vector<char> values = c10d::toVec<char>(seq, kBytes);
    store_->set(kSeqNumStoreKey, values); // store the value
  } else {
    // Read rank 0's sequence number from store.
    sequenceNum_ = c10d::SequenceNum();
    store_->wait({kSeqNumStoreKey}, options_->timeout); // wait
    std::vector<char> values = store_->get(kSeqNumStoreKey); // fetch the value
    uint64_t num = c10d::fromVec<char>(values);
    sequenceNum_->set(num);
  }
}
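The same publish-and-wait pattern can be sketched with the Python Store API; the key name seq_num and the port are made up for this example, and rank/world_size are assumed to come from the environment.

import os
import torch.distributed as dist
from datetime import timedelta

rank = int(os.environ["RANK"])
world_size = int(os.environ["WORLD_SIZE"])
store = dist.TCPStore("127.0.0.1", 1234, world_size, rank == 0, timedelta(seconds=30))

if rank == 0:
    store.set("seq_num", "12345")      # like store_->set(kSeqNumStoreKey, values)
else:
    store.wait(["seq_num"])            # like store_->wait({kSeqNumStoreKey}, timeout)
    seq = int(store.get("seq_num"))    # like store_->get(kSeqNumStoreKey)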
import torch.distributed as dist
from datetime import timedelta

# Run on process 1 (server)
server_store = dist.TCPStore("127.0.0.1", 1234, 2, True, timedelta(seconds=30))

# Run on process 2 (client)
client_store = dist.TCPStore("127.0.0.1", 1234, 2, False)

# Use any of the store methods from either the client or server after initialization
server_store.set("first_key", "first_value")
client_store.get("first_key")
Or:
>>> import torch.distributed as dist
>>> from datetime import timedelta
>>> # Using TCPStore as an example, other store types can also be used
>>> store = dist.TCPStore("127.0.0.1", 0, 1, True, timedelta(seconds=30))
>>> # This will throw an exception after 10 seconds
>>> store.wait(["bad_key"], timedelta(seconds=10))
class TCPStore(Store):
    def __init__(self, host_name, port, world_size=-1, is_master=False, timeout=None, *args, **kwargs):
        # real signature unknown; NOTE: unreliably restored from __doc__
        pass
    host = property(lambda self: object(), lambda self, v: None, lambda self: None)  # default
    """Gets the hostname on which the store listens for requests."""

    port = property(lambda self: object(), lambda self, v: None, lambda self: None)  # default
    """Gets the port number on which the store listens for requests."""
// NOTE: calling other TCPStore APIs inside the callback is NOT threadsafe
// watchKey() is a blocking operation. It will register the socket on
// TCPStoreMasterDaemon and the callback on TCPStoreWorkerDaemon. It will
// return once it has verified the callback is registered on both background
// threads. Only one thread can call watchKey() at a time.
void watchKey(const std::string& key, WatchKeyCallback callback) override;

bool check(const std::vector<std::string>& keys) override;

int64_t getNumKeys() override;

void wait(const std::vector<std::string>& keys) override;

void wait(
    const std::vector<std::string>& keys,
    const std::chrono::milliseconds& timeout) override;

// Waits for all workers to join.
void waitForWorkers();

// Returns the hostname used by the TCPStore.
const std::string& getHost() const noexcept;

// Returns the port used by the TCPStore.
PortType getPort() const noexcept;
The purpose of storeSocket_ is to encapsulate the operations against the master port: the user only calls set, get and so on, and never needs to know about the master port.
set(key, data) sends a request over storeSocket_ to the master, asking it to store key : value.
tcpStoreMasterDaemon_ detects the activity on the socket and starts to respond.
tcpStoreMasterDaemon_ then inserts key : value into its std::unordered_map<std::string, std::vector<uint8_t>> tcpStore_. A minimal sketch of this flow follows below.
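As a purely illustrative sketch (not PyTorch code, and not the real wire format), the idea is simply a daemon thread that owns a map and answers set/get requests coming in over a socket, while the client only sees set and get:

import socket
import threading

def master_daemon(server_sock, kv):
    # The daemon owns the key/value map (the C++ side uses std::unordered_map)
    conn, _ = server_sock.accept()
    f = conn.makefile("rwb")
    while True:
        line = f.readline()
        if not line:
            break
        parts = line.decode().rstrip("\n").split(" ", 2)
        if parts[0] == "SET":
            kv[parts[1]] = parts[2]                    # tcpStore_[key] = value
            f.write(b"OK\n")
        elif parts[0] == "GET":
            f.write((kv.get(parts[1], "") + "\n").encode())
        f.flush()

kv = {}
server = socket.socket()
server.bind(("127.0.0.1", 0))
server.listen(1)
port = server.getsockname()[1]
threading.Thread(target=master_daemon, args=(server, kv), daemon=True).start()

# "storeSocket_": the client only calls set/get; the port is hidden behind it
client = socket.create_connection(("127.0.0.1", port))
cf = client.makefile("rwb")
cf.write(b"SET first_key first_value\n"); cf.flush()
assert cf.readline() == b"OK\n"
cf.write(b"GET first_key\n"); cf.flush()
print(cf.readline().decode().strip())  # -> first_value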
listenSocket_ lives on tcpStoreWorkerDaemon_ and also connects to masterListenSocket_ on masterPort. There is a decoupling here, as the comment puts it: "It will register the socket on TCPStoreMasterDaemon and the callback on TCPStoreWorkerDaemon".
void TCPStore::waitForWorkers() {
  addHelper_(initKey_, 1);
  // Let server block until all workers have completed, this ensures that
  // the server daemon thread is always running until the very end
  if (isServer_) {
    const auto start = std::chrono::steady_clock::now();
    while (true) {
      std::vector<uint8_t> value = getHelper_(initKey_);
      auto buf = reinterpret_cast<const char*>(value.data());
      auto len = value.size();
      int numWorkersCompleted = std::stoi(std::string(buf, len));
      if (numWorkersCompleted >= numWorkers_.value_or(-1)) {
        break;
      }
      const auto elapsed = std::chrono::duration_cast<std::chrono::seconds>(
          std::chrono::steady_clock::now() - start);
      if (timeout_ != kNoTimeout && elapsed > timeout_) {
        break;
      }
      /* sleep override */
      std::this_thread::sleep_for(std::chrono::milliseconds(10));
    }
  }
}
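The counting trick can be sketched with the Python Store API as well: every process bumps a shared counter with add(), and the server polls the counter until it reaches world_size. The key name workers_done and the port are invented for this sketch.

import os
import time
import torch.distributed as dist
from datetime import timedelta

rank = int(os.environ["RANK"])
world_size = int(os.environ["WORLD_SIZE"])
store = dist.TCPStore("127.0.0.1", 1234, world_size, rank == 0, timedelta(seconds=60))

store.add("workers_done", 1)               # like addHelper_(initKey_, 1)
if rank == 0:
    # add(key, 0) reads the counter without changing it
    while store.add("workers_done", 0) < world_size:
        time.sleep(0.01)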
// Separate thread that is launched on all instances (including master)
// Right now only handles callbacks registered from watchKey()
class TCPStoreWorkerDaemon : public BackgroundThread {
 public:
  explicit TCPStoreWorkerDaemon(int listenSocket);
  // Set the callback to run key change
  void setCallback(std::string key, WatchKeyCallback cb);

  void waitForCallbackRegistration() {
    // Block until callback has been registered successfully
    std::unique_lock<std::mutex> callbackRegistrationLock(
        callbackRegistrationMutex_);
    callbackRegisteredCV_.wait(
        callbackRegistrationLock, [&] { return callbackRegisteredData_; });
    // Reset payload for next callback
    callbackRegisteredData_ = false;
  }

  void setCallbackRegistered() {
    callbackRegisteredData_ = true;
    callbackRegisteredCV_.notify_one();
  }
 private:
  void run();
  void callbackHandler(int socket);

  // List of callbacks map each watched key
  std::unordered_map<std::string, WatchKeyCallback> keyToCallbacks_;
  std::mutex keyToCallbacksMutex_;
  std::mutex callbackRegistrationMutex_;
  std::condition_variable callbackRegisteredCV_;
  bool callbackRegisteredData_ = false;
};
void TCPStore::watchKey(const std::string& key, WatchKeyCallback callback) {
  // Only allow one thread to perform watchKey() at a time
  const std::lock_guard<std::mutex> watchKeyLock(watchKeyMutex_);
  // Register callback with TCPStoreMasterDaemon to call TCPStoreWorkerDaemon on
  // key change
  std::string regKey = regularPrefix_ + key;
  tcpStoreWorkerDaemon_->setCallback(regKey, callback);
  tcputil::sendValue<QueryType>(listenSocket_, QueryType::WATCH_KEY);
  tcputil::sendString(listenSocket_, regKey);
  // Block until callback has been registered successfully
  tcpStoreWorkerDaemon_->waitForCallbackRegistration();
}
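Stripped of the sockets, the registration handshake is just a condition variable: the registering thread blocks until the background thread acknowledges that the callback is installed. A pure-Python illustration (not the PyTorch implementation):

import threading

cv = threading.Condition()
registered = False

def worker_daemon():
    global registered
    # ...pretend the WATCH_KEY registration arrived and the callback was installed...
    with cv:
        registered = True          # setCallbackRegistered()
        cv.notify()                # callbackRegisteredCV_.notify_one()

threading.Thread(target=worker_daemon, daemon=True).start()

with cv:
    cv.wait_for(lambda: registered)    # waitForCallbackRegistration()
print("callback registered, watchKey() may return")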
  while (true) {
    // Check control and exit early if triggered
    int res;
    SYSCHECK_ERR_RETURN_NEG1(
        res = WSAPoll(fds.data(), fds.size(), checkTimeout_.count()))
    if (res == 0) {
      auto rvPoll = WaitForSingleObject(ghStopEvent_, 0);
      if (rvPoll != WAIT_TIMEOUT) {
        break;
      }
      continue;
    }
    // if connection is closed gracefully by master, peeked data will return 0
    char data;
    int ret = recv(fds[0].fd, &data, 1, MSG_PEEK);
    if (ret == 0) {
      auto rvData = WaitForSingleObject(ghStopEvent_, 0);
      if (rvData != WAIT_TIMEOUT) {
        break;
      }
      continue;
    }
  while (true) {
    SYSCHECK_ERR_RETURN_NEG1(::poll(fds.data(), fds.size(), -1));
    // Check control and exit early if triggered
    // The pipe receives an event which tells us to shutdown the listener thread
    if (fds[0].revents != 0) {
      // Will be POLLHUP when the pipe is closed
      if (fds[0].revents ^ POLLHUP) {
        throw std::system_error(
            ECONNABORTED,
            std::system_category(),
            "Unexpected poll revent on the control pipe's reading fd: " +
                std::to_string(fds[0].revents));
      }
      break;
    }
    // if connection is closed gracefully by master, peeked data will return 0
    char data;
    int ret = recv(fds[1].fd, &data, 1, MSG_PEEK);
    if (ret == 0) {
      continue;
    }
// Separate thread that is only launched on master
class TCPStoreMasterDaemon : public BackgroundThread {
 public:
  explicit TCPStoreMasterDaemon(int storeListenSocket);
  // The master runs on a single thread so only
  // one handler can be executed at a time
  void setHandler(int socket);
  void compareSetHandler(int socket);
  void addHandler(int socket);
  void getHandler(int socket) const;
  void checkHandler(int socket) const;
  void getNumKeysHandler(int socket) const;
  void deleteHandler(int socket);
  void waitHandler(int socket);
  void watchHandler(int socket);
  bool checkKeys(const std::vector<std::string>& keys) const;
  // Helper function to alert waiting workers, used in setHandler, getHandler
  void wakeupWaitingClients(const std::string& key);
  // Helper function used when the key is changed
  // used in setHandler, addHandler, getHandler, deleteHandler
  void sendKeyUpdatesToClients(
      const std::string& key,
      const enum WatchResponseType& type,
      std::vector<uint8_t>& oldData,
      std::vector<uint8_t>& newData);

  std::unordered_map<std::string, std::vector<uint8_t>> tcpStore_;
  // From key -> the list of sockets waiting on the key
  std::unordered_map<std::string, std::vector<int>> waitingSockets_;
  // From socket -> number of keys awaited
  std::unordered_map<int, size_t> keysAwaited_;
  // From key -> the list of sockets watching the key
  std::unordered_map<std::string, std::vector<int>> watchedSockets_;
};
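From the user's side, waitingSockets_ and wakeupWaitingClients() are what make a blocking wait() return once some other process sets the key. A usage sketch across two processes (the port and the key name are arbitrary):

import torch.distributed as dist
from datetime import timedelta

# Run on process 0 (owns the master daemon)
store = dist.TCPStore("127.0.0.1", 1234, 2, True, timedelta(seconds=60))
store.set("ready", "1")       # master wakes up every socket parked in waitingSockets_["ready"]

# Run on process 1
store = dist.TCPStore("127.0.0.1", 1234, 2, False, timedelta(seconds=60))
store.wait(["ready"])         # blocks; the master records this socket in waitingSockets_
print(store.get("ready"))     # b'1'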
  // receive the queries
  bool finished = false;
  while (!finished) {
    for (size_t i = 0; i < sockets_.size(); i++) {
      fds[i].revents = 0;
    }
    int res;
    SYSCHECK_ERR_RETURN_NEG1(
        res = WSAPoll(fds.data(), fds.size(), checkTimeout_.count()))
    if (res == 0) {
      auto rv = WaitForSingleObject(ghStopEvent_, 0);
      if (rv != WAIT_TIMEOUT) {
        finished = true;
        break;
      }
      continue;
    }
    // TCPStore's listening socket has an event and it should now be able to
    // accept new connections.
    if (fds[0].revents != 0) { // an event arrived
      if (!(fds[0].revents & POLLIN)) {
        throw std::system_error(
            ECONNABORTED,
            std::system_category(),
            "Unexpected poll revent on the master's listening socket: " +
                std::to_string(fds[0].revents));
      }
      int sockFd = std::get<0>(tcputil::accept(storeListenSocket_));
      sockets_.push_back(sockFd);
      tcputil::addPollfd(fds, sockFd, POLLIN);
    }
    queryFds(fds); // handle the actual requests
  }
}

#else

void TCPStoreMasterDaemon::run() {
  std::vector<struct pollfd> fds;
  tcputil::addPollfd(fds, storeListenSocket_, POLLIN);
  // Push the read end of the pipe to signal the stopping of the daemon run
  tcputil::addPollfd(fds, controlPipeFd_[0], POLLHUP);
  // receive the queries
  bool finished = false;
  while (!finished) {
    for (size_t i = 0; i < sockets_.size(); i++) {
      fds[i].revents = 0;
    }
    // TCPStore's listening socket has an event and it should now be able to
    // accept new connections.
    if (fds[0].revents != 0) {
      if (fds[0].revents ^ POLLIN) {
        throw std::system_error(
            ECONNABORTED,
            std::system_category(),
            "Unexpected poll revent on the master's listening socket: " +
                std::to_string(fds[0].revents));
      }
      int sockFd = std::get<0>(tcputil::accept(storeListenSocket_));
      sockets_.push_back(sockFd);
      tcputil::addPollfd(fds, sockFd, POLLIN);
    }
    // The pipe receives an event which tells us to shutdown the daemon
    if (fds[1].revents != 0) { // an event arrived
      // Will be POLLHUP when the pipe is closed
      if (fds[1].revents ^ POLLHUP) {
        throw std::system_error(
            ECONNABORTED,
            std::system_category(),
            "Unexpected poll revent on the control pipe's reading fd: " +
                std::to_string(fds[1].revents));
      }
      finished = true;
      break;
    }
    queryFds(fds); // handle the actual requests
  }
}

#endif
void TCPStoreMasterDaemon::queryFds(std::vector<struct pollfd>& fds) {
  // Skipping the fds[0] and fds[1],
  // fds[0] is master's listening socket
  // fds[1] is control pipe's reading fd, it is not for Windows platform
  for (size_t fdIdx = CONNECT_SOCKET_OFFSET; fdIdx < fds.size(); ++fdIdx) {
    if (fds[fdIdx].revents == 0) {
      continue;
    }
    // Now query the socket that has the event
    try {
      query(fds[fdIdx].fd); // handle the request
    } catch (...) {
      tcputil::closeSocket(fds[fdIdx].fd);
      // Remove all the tracking state of the closed fd
      for (auto it = waitingSockets_.begin(); it != waitingSockets_.end();) {
        for (auto vecIt = it->second.begin(); vecIt != it->second.end();) {
          if (*vecIt == fds[fdIdx].fd) {
            vecIt = it->second.erase(vecIt);
          } else {
            ++vecIt;
          }
        }
        if (it->second.size() == 0) {
          it = waitingSockets_.erase(it);
        } else {
          ++it;
        }
      }

      for (auto it = keysAwaited_.begin(); it != keysAwaited_.end();) {
        if (it->first == fds[fdIdx].fd) {
          it = keysAwaited_.erase(it);
        } else {
          ++it;
        }
      }

      fds.erase(fds.begin() + fdIdx);
      sockets_.erase(sockets_.begin() + fdIdx - CONNECT_SOCKET_OFFSET);
      --fdIdx;
      continue;
    }
  }
}
// query communicates with the worker. The format
// of the query is as follows:
// type of query | size of arg1 | arg1 | size of arg2 | arg2 | ...
// or, in the case of wait
// type of query | number of args | size of arg1 | arg1 | ...
void TCPStoreMasterDaemon::query(int socket) {
  QueryType qt;
  tcputil::recvBytes<QueryType>(socket, &qt, 1);
  if (qt == QueryType::SET) {
    setHandler(socket);