From 9fbe3f33366314e59241619348386b39a8dd6f4f Mon Sep 17 00:00:00 2001 From: "jphilip@ed.ac.uk" Date: Mon, 11 Jan 2021 11:57:21 +0000 Subject: [PATCH] Style and consistency changes Applying formatting from clang-format. Inconsistency in Request and RequestSentence due to private variables resolved by means of changing to class. (#11) Comments added to request.h, since it holds multiple classes. (#12) Also takes care of consistent indendation and spacing (#10) --- src/bergamot/batch_translator.cpp | 58 ++++---- src/bergamot/batch_translator.h | 17 +-- src/bergamot/batcher.cpp | 33 +++-- src/bergamot/batcher.h | 10 +- src/bergamot/main.cpp | 17 ++- src/bergamot/pcqueue.h | 224 +++++++++++++++--------------- src/bergamot/request.cpp | 122 ++++++++-------- src/bergamot/request.h | 115 +++++++++------ src/bergamot/service.cpp | 85 +++++------- src/bergamot/service.h | 23 +-- src/bergamot/textops.cpp | 44 +++--- src/bergamot/textops.h | 19 ++- src/bergamot/translation_result.h | 8 +- 13 files changed, 385 insertions(+), 390 deletions(-) diff --git a/src/bergamot/batch_translator.cpp b/src/bergamot/batch_translator.cpp index 0fa1b45..8a032d7 100644 --- a/src/bergamot/batch_translator.cpp +++ b/src/bergamot/batch_translator.cpp @@ -6,18 +6,14 @@ namespace marian { namespace bergamot { BatchTranslator::BatchTranslator(DeviceId const device, - PCQueue *pcqueue, - Ptr options) - : device_(device), options_(options){ - - + PCQueue *pcqueue, Ptr options) + : device_(device), options_(options) { ABORT_IF(thread_ != NULL, "Don't call start on a running worker!"); - thread_.reset(new std::thread([&]{ this->mainloop(pcqueue); })); - + thread_.reset(new std::thread([&] { this->mainloop(pcqueue); })); } -void BatchTranslator::initGraph(){ +void BatchTranslator::initGraph() { vocabs_ = loadVocabularies(options_); if (options_->hasAndNotEmpty("shortlist")) { Ptr slgen; @@ -27,9 +23,8 @@ void BatchTranslator::initGraph(){ options_, vocabs_.front(), vocabs_.back(), srcIdx, trgIdx, shared_vcb); } - graph_ = New(true); // always optimize - auto prec = - options_->get>("precision", {"float32"}); + graph_ = New(true); // always optimize + auto prec = options_->get>("precision", {"float32"}); graph_->setDefaultElementType(typeFromString(prec[0])); graph_->setDevice(device_); graph_->getBackend()->configureDevice(options_); @@ -45,8 +40,8 @@ void BatchTranslator::initGraph(){ graph_->forward(); } -void BatchTranslator::translate(const Ptr segments, - Histories &histories){ +void BatchTranslator::translate(const Ptr segments, + Histories &histories) { int id = 0; std::vector batchVector; Timer timer; @@ -64,9 +59,11 @@ void BatchTranslator::translate(const Ptr segments, std::vector sentenceIds; std::vector maxDims; for (auto &ex : batchVector) { - if (maxDims.size() < ex.size()) maxDims.resize(ex.size(), 0); + if (maxDims.size() < ex.size()) + maxDims.resize(ex.size(), 0); for (size_t i = 0; i < ex.size(); ++i) { - if (ex[i].size() > (size_t)maxDims[i]) maxDims[i] = (int)ex[i].size(); + if (ex[i].size() > (size_t)maxDims[i]) + maxDims[i] = (int)ex[i].size(); } sentenceIds.push_back(ex.getId()); } @@ -79,7 +76,6 @@ void BatchTranslator::translate(const Ptr segments, subBatches.emplace_back(New(batchSize, maxDims[j], vocabs_[j])); } - PLOG(_identifier(), info, "subBatches created in {}; ", timer.elapsed()); timer.reset(); @@ -94,7 +90,7 @@ void BatchTranslator::translate(const Ptr segments, } } - for (size_t j = 0; j < maxDims.size(); ++j) + for (size_t j = 0; j < maxDims.size(); ++j) subBatches[j]->setWords(words[j]); auto batch = Ptr(new CorpusBatch(subBatches)); @@ -102,8 +98,6 @@ void BatchTranslator::translate(const Ptr segments, PLOG(_identifier(), info, "corpusBatch created in {}; ", timer.elapsed()); timer.reset(); - - auto trgVocab = vocabs_.back(); auto search = New(options_, scorers_, trgVocab); @@ -111,17 +105,16 @@ void BatchTranslator::translate(const Ptr segments, PLOG(_identifier(), info, "BeamSearch completed in {}; ", timer.elapsed()); timer.reset(); - } -void BatchTranslator::mainloop(PCQueue *pcqueue){ +void BatchTranslator::mainloop(PCQueue *pcqueue) { initGraph(); - while(running_){ + while (running_) { Timer timer; PCItem pcitem; pcqueue->Consume(pcitem); - if(pcitem.isPoison()){ - running_ = false; + if (pcitem.isPoison()) { + running_ = false; } else { PLOG(_identifier(), info, "consumed item in {}; ", timer.elapsed()); timer.reset(); @@ -129,20 +122,19 @@ void BatchTranslator::mainloop(PCQueue *pcqueue){ translate(pcitem.segments, histories); PLOG(_identifier(), info, "translated item in {}; ", timer.elapsed()); timer.reset(); - for(int i=0; i < (pcitem.sentences)->size(); i++){ + for (int i = 0; i < (pcitem.sentences)->size(); i++) { Ptr history = histories.at(i); - Ptr request = ((pcitem.sentences)->at(i)).request; - int index = ((pcitem.sentences)->at(i)).index; - request->set_translation(index, history); + RequestSentence requestSentence = pcitem.sentences->at(i); + requestSentence.completeSentence(history); } } } } -void BatchTranslator::join(){ - thread_->join(); - thread_.reset(); +void BatchTranslator::join() { + thread_->join(); + thread_.reset(); } -} // namespace bergamot -} // namespace marian +} // namespace bergamot +} // namespace marian diff --git a/src/bergamot/batch_translator.h b/src/bergamot/batch_translator.h index f7a210e..5234da6 100644 --- a/src/bergamot/batch_translator.h +++ b/src/bergamot/batch_translator.h @@ -1,12 +1,12 @@ #ifndef __BERGAMOT_BATCH_TRANSLATOR_H #define __BERGAMOT_BATCH_TRANSLATOR_H +#include #include #include #include #include #include -#include #include "common/logging.h" #include "common/utils.h" @@ -14,12 +14,12 @@ #include "data/corpus.h" #include "data/shortlist.h" #include "data/text_input.h" -#include "translator/history.h" -#include "translator/scorers.h" #include "definitions.h" -#include "translator/beam_search.h" #include "pcqueue.h" #include "request.h" +#include "translator/beam_search.h" +#include "translator/history.h" +#include "translator/scorers.h" #include "sanelogging.h" @@ -29,8 +29,7 @@ namespace bergamot { class BatchTranslator { public: BatchTranslator(const BatchTranslator &) = default; - BatchTranslator(DeviceId const device, - PCQueue *pcqueue, + BatchTranslator(DeviceId const device, PCQueue *pcqueue, Ptr options); void initGraph(); @@ -39,7 +38,6 @@ class BatchTranslator { std::string _identifier() { return "worker" + std::to_string(device_.no); } void join(); - private: Ptr options_; DeviceId device_; @@ -50,9 +48,8 @@ class BatchTranslator { bool running_{true}; std::unique_ptr thread_; - }; -} // namespace bergamot -} // namespace marian +} // namespace bergamot +} // namespace marian #endif // __BERGAMOT_BATCH_TRANSLATOR_H diff --git a/src/bergamot/batcher.cpp b/src/bergamot/batcher.cpp index cb03ce9..3d6ca17 100644 --- a/src/bergamot/batcher.cpp +++ b/src/bergamot/batcher.cpp @@ -11,19 +11,18 @@ Batcher::Batcher(Ptr options) { max_input_tokens_ = options->get("max-input-tokens"); max_input_sentence_tokens_ = options->get("max-input-sentence-tokens"); bucket.reserve(max_input_sentence_tokens_ + 1); - for(int i=0; i<=max_input_sentence_tokens_; i++){ + for (int i = 0; i <= max_input_sentence_tokens_; i++) { bucket.push_back(std::set()); } } -void Batcher::addSentenceWithPriority(RequestSentence &sentence){ - int bucket_id = sentence.num_tokens(); +void Batcher::addSentenceWithPriority(RequestSentence &sentence) { + int bucket_id = sentence.numTokens(); assert(bucket_id <= max_input_sentence_tokens_); bucket[bucket_id].insert(sentence); } - -void Batcher::cleave_batch(Ptr segments, +void Batcher::cleave_batch(Ptr segments, Ptr sentences) { /* Temporary stub, needs improvement this section */ int segments_added = 0; @@ -32,28 +31,28 @@ void Batcher::cleave_batch(Ptr segments, int prev_padded_batch_size; for (int i = 0; i < bucket.size(); i++) { auto p = bucket[i].begin(); - while ( p != bucket[i].end() ){ - padded_batch_size = (segments_added+1)*i; - if (padded_batch_size < max_input_tokens_){ + while (p != bucket[i].end()) { + padded_batch_size = (segments_added + 1) * i; + if (padded_batch_size < max_input_tokens_) { auto q = p; current_input_tokens += i; - segments->push_back(q->segment()); + Segment segment = q->getUnderlyingSegment(); + segments->push_back(std::move(segment)); sentences->push_back(*q); ++p; ++segments_added; bucket[i].erase(q); prev_padded_batch_size = padded_batch_size; - } - else{ - PLOG("main", info, "New batch generated; {} Segments added;", segments_added); - PLOG("main", info, - "padded_batch_size ({}) current_input_tokens({})", - prev_padded_batch_size, current_input_tokens); + } else { + PLOG("main", info, "New batch generated; {} Segments added;", + segments_added); + PLOG("main", info, "padded_batch_size ({}) current_input_tokens({})", + prev_padded_batch_size, current_input_tokens); return; } } } } -} // namespace bergamot -} // namespace marian +} // namespace bergamot +} // namespace marian diff --git a/src/bergamot/batcher.h b/src/bergamot/batcher.h index 6b8eb66..49f804f 100644 --- a/src/bergamot/batcher.h +++ b/src/bergamot/batcher.h @@ -6,9 +6,8 @@ #include "definitions.h" #include "request.h" -#include #include - +#include namespace marian { namespace bergamot { @@ -19,14 +18,13 @@ class Batcher { unsigned int max_input_sentence_tokens_; std::vector> bucket; - public: +public: explicit Batcher(Ptr options); void addSentenceWithPriority(RequestSentence &); void cleave_batch(Ptr, Ptr); }; -} // namespace bergamot -} // namespace marian - +} // namespace bergamot +} // namespace marian #endif // __BERGAMOT_BATCHER_H diff --git a/src/bergamot/main.cpp b/src/bergamot/main.cpp index d5245b6..508869a 100644 --- a/src/bergamot/main.cpp +++ b/src/bergamot/main.cpp @@ -12,7 +12,6 @@ #include "service.h" - int main(int argc, char *argv[]) { marian::ConfigParser cp(marian::cli::mode::translation); @@ -31,8 +30,8 @@ int main(int argc, char *argv[]) { cp.addOption( "--ssplit-prefix-file", "Server Options", "File with nonbreaking prefixes for sentence splitting."); - cp.addOption( - "--ssplit-mode", "Server Options", "[paragraph, sentence, wrapped_text]"); + cp.addOption("--ssplit-mode", "Server Options", + "[paragraph, sentence, wrapped_text]"); cp.addOption("--source-language", "Server Options", "source language of translation service"); cp.addOption("--target-language", "Server Options", @@ -57,8 +56,8 @@ int main(int argc, char *argv[]) { // std::getline(std::cin, input); // std::cout << input << "\n"; - std::ostringstream std_input; - std_input << std::cin.rdbuf(); + std::ostringstream std_input; + std_input << std::cin.rdbuf(); std::string input = std_input.str(); marian::string_view input_view(input); @@ -68,10 +67,10 @@ int main(int argc, char *argv[]) { auto translation_result_future = service.translate(input_view); translation_result_future.wait(); auto translation_result = translation_result_future.get(); - for (int i=0; i < translation_result.sources.size(); i++){ - std::cout<< "[src] " << translation_result.sources[i]<<"\n"; - std::cout<< "[tgt] " << translation_result.translations[i]<<"\n"; - std::cout<< "--------------------------------\n"; + for (int i = 0; i < translation_result.sources.size(); i++) { + std::cout << "[src] " << translation_result.sources[i] << "\n"; + std::cout << "[tgt] " << translation_result.translations[i] << "\n"; + std::cout << "--------------------------------\n"; } service.stop(); diff --git a/src/bergamot/pcqueue.h b/src/bergamot/pcqueue.h index 52a8f93..8e2d7a6 100644 --- a/src/bergamot/pcqueue.h +++ b/src/bergamot/pcqueue.h @@ -10,10 +10,10 @@ #include #ifdef __APPLE__ +#include +#include #include #include -#include -#include #elif defined(__linux) #include #else @@ -21,13 +21,11 @@ #endif #if __GNUC__ >= 3 -#define UTIL_UNLIKELY(x) __builtin_expect (!!(x), 0) +#define UTIL_UNLIKELY(x) __builtin_expect(!!(x), 0) #else #define UTIL_UNLIKELY(x) (x) #endif - - namespace marian { namespace bergamot { @@ -37,78 +35,77 @@ namespace bergamot { #ifdef __APPLE__ class Semaphore { - public: - explicit Semaphore(int value) : task_(mach_task_self()) { - ABORT_IF(KERN_SUCCESS != semaphore_create(task_, &back_, SYNC_POLICY_FIFO, value), "Could not create semaphore"); - } +public: + explicit Semaphore(int value) : task_(mach_task_self()) { + ABORT_IF(KERN_SUCCESS != + semaphore_create(task_, &back_, SYNC_POLICY_FIFO, value), + "Could not create semaphore"); + } - ~Semaphore() { - if (KERN_SUCCESS != semaphore_destroy(task_, back_)) { - std::cerr << "Could not destroy semaphore" << std::endl; - abort(); - } + ~Semaphore() { + if (KERN_SUCCESS != semaphore_destroy(task_, back_)) { + std::cerr << "Could not destroy semaphore" << std::endl; + abort(); } + } - void wait() { - ABORT_IF(KERN_SUCCESS != semaphore_wait(back_), Exception, "Wait for semaphore failed"); - } + void wait() { + ABORT_IF(KERN_SUCCESS != semaphore_wait(back_), Exception, + "Wait for semaphore failed"); + } - void post() { - ABORT_IF(KERN_SUCCESS != semaphore_signal(back_), Exception, "Could not post to semaphore"); - } + void post() { + ABORT_IF(KERN_SUCCESS != semaphore_signal(back_), Exception, + "Could not post to semaphore"); + } - private: - semaphore_t back_; - task_t task_; +private: + semaphore_t back_; + task_t task_; }; -inline void WaitSemaphore(Semaphore &semaphore) { - semaphore.wait(); -} +inline void WaitSemaphore(Semaphore &semaphore) { semaphore.wait(); } #elif defined(__linux) class Semaphore { - public: - explicit Semaphore(unsigned int value) { - ABORT_IF(sem_init(&sem_, 0, value), "Could not create semaphore"); - } +public: + explicit Semaphore(unsigned int value) { + ABORT_IF(sem_init(&sem_, 0, value), "Could not create semaphore"); + } - ~Semaphore() { - if (-1 == sem_destroy(&sem_)) { - std::cerr << "Could not destroy semaphore " << std::endl; - abort(); - } + ~Semaphore() { + if (-1 == sem_destroy(&sem_)) { + std::cerr << "Could not destroy semaphore " << std::endl; + abort(); } + } - void wait() { - while (UTIL_UNLIKELY(-1 == sem_wait(&sem_))) { - ABORT_IF(errno != EINTR, "Wait for semaphore failed"); - } + void wait() { + while (UTIL_UNLIKELY(-1 == sem_wait(&sem_))) { + ABORT_IF(errno != EINTR, "Wait for semaphore failed"); } + } - void post() { - ABORT_IF(-1 == sem_post(&sem_), "Could not post to semaphore"); - } + void post() { + ABORT_IF(-1 == sem_post(&sem_), "Could not post to semaphore"); + } - private: - sem_t sem_; +private: + sem_t sem_; }; -inline void WaitSemaphore(Semaphore &semaphore) { - semaphore.wait(); -} +inline void WaitSemaphore(Semaphore &semaphore) { semaphore.wait(); } #else typedef boost::interprocess::interprocess_semaphore Semaphore; -inline void WaitSemaphore (Semaphore &on) { +inline void WaitSemaphore(Semaphore &on) { while (1) { try { on.wait(); break; - } - catch (boost::interprocess::interprocess_exception &e) { + } catch (boost::interprocess::interprocess_exception &e) { if (e.get_native_error() != EINTR) { throw; } @@ -123,16 +120,15 @@ inline void WaitSemaphore (Semaphore &on) { * T must be default constructable and have operator=. * The value is copied twice for Consume(T &out) or three times for Consume(), * so larger objects should be passed via pointer. - * Strong exception guarantee if operator= throws. Undefined if semaphores throw. + * Strong exception guarantee if operator= throws. Undefined if semaphores + * throw. */ template class PCQueue { - public: +public: explicit PCQueue(size_t size) - : empty_(size), used_(0), - storage_(new T[size]), - end_(storage_.get() + size), - produce_at_(storage_.get()), - consume_at_(storage_.get()) {} + : empty_(size), used_(0), storage_(new T[size]), + end_(storage_.get() + size), produce_at_(storage_.get()), + consume_at_(storage_.get()) {} // Add a value to the queue. void Produce(const T &val) { @@ -145,7 +141,8 @@ template class PCQueue { empty_.post(); throw; } - if (++produce_at_ == end_) produce_at_ = storage_.get(); + if (++produce_at_ == end_) + produce_at_ = storage_.get(); } used_.post(); } @@ -161,14 +158,14 @@ template class PCQueue { empty_.post(); throw; } - if (++produce_at_ == end_) produce_at_ = storage_.get(); + if (++produce_at_ == end_) + produce_at_ = storage_.get(); } used_.post(); } - // Consume a value, assigning it to out. - T& Consume(T &out) { + T &Consume(T &out) { WaitSemaphore(used_); { std::lock_guard consume_lock(consume_at_mutex_); @@ -178,14 +175,15 @@ template class PCQueue { used_.post(); throw; } - if (++consume_at_ == end_) consume_at_ = storage_.get(); + if (++consume_at_ == end_) + consume_at_ = storage_.get(); } empty_.post(); return out; } // Consume a value, swapping it to out. - T& ConsumeSwap(T &out) { + T &ConsumeSwap(T &out) { WaitSemaphore(used_); { std::lock_guard consume_lock(consume_at_mutex_); @@ -195,13 +193,13 @@ template class PCQueue { used_.post(); throw; } - if (++consume_at_ == end_) consume_at_ = storage_.get(); + if (++consume_at_ == end_) + consume_at_ = storage_.get(); } empty_.post(); return out; } - // Convenience version of Consume that copies the value to return. // The other version is faster. T Consume() { @@ -210,7 +208,7 @@ template class PCQueue { return ret; } - private: +private: // Number of empty spaces in storage_. Semaphore empty_; // Number of occupied spaces in storage_. @@ -236,67 +234,63 @@ template struct UnboundedPage { }; template class UnboundedSingleQueue { - public: - UnboundedSingleQueue() : valid_(0) { - SetFilling(new UnboundedPage()); - SetReading(filling_); - } +public: + UnboundedSingleQueue() : valid_(0) { + SetFilling(new UnboundedPage()); + SetReading(filling_); + } - void Produce(T &&val) { - if (filling_current_ == filling_end_) { - UnboundedPage *next = new UnboundedPage(); - filling_->next = next; - SetFilling(next); - } - *(filling_current_++) = std::move(val); - valid_.post(); + void Produce(T &&val) { + if (filling_current_ == filling_end_) { + UnboundedPage *next = new UnboundedPage(); + filling_->next = next; + SetFilling(next); } + *(filling_current_++) = std::move(val); + valid_.post(); + } - void Produce(const T &val) { - Produce(T(val)); - } + void Produce(const T &val) { Produce(T(val)); } - T& Consume(T &out) { - WaitSemaphore(valid_); - if (reading_current_ == reading_end_) { - SetReading(reading_->next); - } - out = std::move(*(reading_current_++)); - return out; + T &Consume(T &out) { + WaitSemaphore(valid_); + if (reading_current_ == reading_end_) { + SetReading(reading_->next); } + out = std::move(*(reading_current_++)); + return out; + } - // Warning: very much a no-guarantees race-condition-rich implementation! - // But sufficient for our specific purpose: The single thread that consumes - // is also the only one that checks Empty, and knows that it's racing. - bool Empty() const { - return reading_current_ == filling_current_; - } + // Warning: very much a no-guarantees race-condition-rich implementation! + // But sufficient for our specific purpose: The single thread that consumes + // is also the only one that checks Empty, and knows that it's racing. + bool Empty() const { return reading_current_ == filling_current_; } - private: - void SetFilling(UnboundedPage *to) { - filling_ = to; - filling_current_ = to->entries; - filling_end_ = filling_current_ + sizeof(to->entries) / sizeof(T); - } - void SetReading(UnboundedPage *to) { - reading_.reset(to); - reading_current_ = to->entries; - reading_end_ = reading_current_ + sizeof(to->entries) / sizeof(T); - } +private: + void SetFilling(UnboundedPage *to) { + filling_ = to; + filling_current_ = to->entries; + filling_end_ = filling_current_ + sizeof(to->entries) / sizeof(T); + } + void SetReading(UnboundedPage *to) { + reading_.reset(to); + reading_current_ = to->entries; + reading_end_ = reading_current_ + sizeof(to->entries) / sizeof(T); + } - Semaphore valid_; + Semaphore valid_; - UnboundedPage *filling_; + UnboundedPage *filling_; - std::unique_ptr > reading_; + std::unique_ptr> reading_; - T *filling_current_; - T *filling_end_; - T *reading_current_; - T *reading_end_; + T *filling_current_; + T *filling_end_; + T *reading_current_; + T *reading_end_; - UnboundedSingleQueue(const UnboundedSingleQueue &) = delete; - UnboundedSingleQueue &operator=(const UnboundedSingleQueue &) = delete; + UnboundedSingleQueue(const UnboundedSingleQueue &) = delete; + UnboundedSingleQueue &operator=(const UnboundedSingleQueue &) = delete; }; } // namespace bergamot diff --git a/src/bergamot/request.cpp b/src/bergamot/request.cpp index 24f0ad0..260428e 100644 --- a/src/bergamot/request.cpp +++ b/src/bergamot/request.cpp @@ -1,87 +1,93 @@ -#include "common/logging.h" #include "request.h" + #include "definitions.h" #include "translation_result.h" -#include "sys/time.h" -#include + +#include "common/logging.h" + +#include namespace marian { namespace bergamot { -Request::Request(unsigned int Id, - std::vector> vocabs, - string_view reference, - Ptr segments, +Request::Request(unsigned int Id, std::vector> vocabs, + string_view reference, Ptr segments, Ptr sourceAlignments, Ptr> translationResultPromise) - : Id(Id), - vocabs_(vocabs), reference_(reference), - segments(segments), - sourceAlignments(sourceAlignments), - response_(translationResultPromise), + : Id_(Id), vocabs_(vocabs), reference_(reference), segments_(segments), + sourceAlignments_(sourceAlignments), response_(translationResultPromise), counter_(segments->size()) { - for(int i=0; i < segments->size(); i++){ - histories_.push_back(nullptr); - } + // Set vector> to nullptr. + for (int i = 0; i < segments_->size(); i++) { + histories_.push_back(nullptr); + } +} + +int Request::numSegments() const { return segments_->size(); } +int Request::segmentTokens(int index) const { + return (segments_->at(index)).size(); } -int Request::size(){ - return segments->size(); -}; +Segment Request::getSegment(int index) const { return segments_->at(index); } -void Request::set_translation(int index, Ptr history) { - /* This can be accessed by multiple batch_translators at once. */ - // std::lock_guard request_lock(update_mutex_); +void Request::processHistory(int index, Ptr history) { + // Concurrently called by multiple workers as a history from translation is + // ready. The container storing histories is set with the value obtained. histories_[index] = history; - if(--counter_ == 0){ - TranslationResult translation_result; - for(int i=0; i < segments->size(); i++){ - translation_result.sources.push_back( - vocabs_.front()->decode(segments->at(i)) - ); - - history = histories_[i]; - NBestList onebest = history->nBest(1); - Result result = onebest[0]; // Expecting only one result; - Words words = std::get<0>(result); - translation_result.translations.push_back( - vocabs_.back()->decode(words) - ); - - } - LOG(info, "Last translation in. Closing request;"); - response_->set_value(translation_result); + + // In case this is last request in, completeRequest is called, which sets the + // value of the promise. + if (--counter_ == 0) { + completeRequest(); } } -bool operator<(const Request &a, const Request &b) { - // TODO(jerin): Probably enhance - return a.Id < b.Id; +void Request::completeRequest() { + TranslationResult translation_result; + for (int i = 0; i < segments_->size(); i++) { + std::string source = vocabs_.front()->decode(getSegment(i)); + translation_result.sources.push_back(source); + + Ptr history = histories_[i]; + NBestList onebest = history->nBest(1); + Result result = onebest[0]; // Expecting only one result; + Words words = std::get<0>(result); + std::string decoded = vocabs_.back()->decode(words); + translation_result.translations.push_back(decoded); + } + LOG(info, "Last translation in. Closing request;"); + response_->set_value(translation_result); } -RequestSentence::RequestSentence(int index, - Ptr request) - : index(index), request(request) {} +bool Request::operator<(const Request &b) const { + // Among Requests, only sequence id is used for obtaining priority. + return Id_ < b.Id_; +} + +RequestSentence::RequestSentence(int index, Ptr request) + : index_(index), request_(request) {} + +int RequestSentence::numTokens() { return (request_->segmentTokens(index_)); } -int RequestSentence::num_tokens(){ - return (request->segments->at(index).size()); +void RequestSentence::completeSentence(Ptr history) { + // Relays completeSentence into request's processHistory, using index + // information. + request_->processHistory(index_, history); } -Segment RequestSentence::segment() const { - return request->segments->at(index); +Segment RequestSentence::getUnderlyingSegment() const { + return request_->getSegment(index_); } -bool operator<(const RequestSentence& a, const RequestSentence& b) { - if(a.request == b.request){ - return a.index < b.index; +bool operator<(const RequestSentence &a, const RequestSentence &b) { + // Operator overload for usage in priority-queue / set. + if (a.request_ == b.request_) { + return a.index_ < b.index_; } - return a < b; + return a.request_ < b.request_; } -bool operator==(const RequestSentence& a, const RequestSentence& b) { - return (a.index == b.index) and (a.request == b.request); -}; -} // namespace bergamot -} // namespace marian +} // namespace bergamot +} // namespace marian diff --git a/src/bergamot/request.h b/src/bergamot/request.h index 3fbd5eb..f5f0311 100644 --- a/src/bergamot/request.h +++ b/src/bergamot/request.h @@ -1,79 +1,104 @@ -#ifndef __BERGAMOT_REQUEST_H -#define __BERGAMOT_REQUEST_H +// +// Defines: +// +// Request: holds the input blob of a text, Segments (vector) which are +// to go to the batching mechanism and alignments between the processed +// segments and the input blob (sourceAlignments). In addition, Request takes +// care of the barrier which fires when all the Segments in a request are done +// translating by the workers (BatchTranslator). Request is to be extended with +// notions of Priority (sequence, user-given). +// +// RequestSentence: is a mapping of (index, Request*). This provides the +// batching mechanism access to the segment within the request. The backref to +// Request allows event triggering the barrier upon completion of the last +// sentence by a worker. +// +// PCItem: is a vector of RequestSentences and the corresponding Segments, +// which is what the ProducerConsumer queue holds. Can probably get rid of +// Segment here and use RequestSentence directly to construct batches. +// Separation is worker(BatchTranslator) need not be aware of the notion of +// Request, but only a Batch of segments. + +#ifndef SRC_BERGAMOT_REQUEST_H_ +#define SRC_BERGAMOT_REQUEST_H_ -#include "data/types.h" -#include "sys/time.h" -#include "translation_result.h" #include "definitions.h" -#include -#include "translator/beam_search.h" -#include -#include +#include "translation_result.h" +#include "data/types.h" +#include "translator/beam_search.h" +#include +#include namespace marian { namespace bergamot { -struct Request { +class Request { +private: + unsigned int Id_; string_view reference_; - Ptr segments; - Ptr sourceAlignments; + Ptr segments_; + Ptr sourceAlignments_; Ptr> response_; - unsigned int Id; std::vector> histories_; std::atomic counter_; - // @TODO(jerin): This is a bit weird, need to do better. - std::vector> vocabs_; + std::vector> vocabs_; - Request(unsigned int, - std::vector>, - string_view, - Ptr, - Ptr, +public: + Request(unsigned int, std::vector>, string_view, + Ptr, Ptr, Ptr>); - void set_translation(int index, Ptr); - int size(); -}; -struct RequestSentence { - /* A sentence tied to a request. */ - int index; - Ptr request; - RequestSentence(int, Ptr); - int num_tokens(); - Segment segment() const; + void processHistory(int index, Ptr); + void completeRequest(); + + // Obtain the count of tokens in a segment. Used to insert sentence from + // multiple requests into the corresponding size bucket. + int segmentTokens(int) const; + // Obtain number of segments in a request. + int numSegments() const; + + // Obtains a segment to create a batch of segments among several requests. + Segment getSegment(int) const; + + bool operator<(const Request &) const; }; -bool operator<(const RequestSentence& a, const RequestSentence& b); -bool operator==(const RequestSentence& a, const RequestSentence& b); +class RequestSentence { +private: + int index_; + Ptr request_; + +public: + RequestSentence(int, Ptr); + int numTokens(); + Segment getUnderlyingSegment() const; + void completeSentence(Ptr); + friend bool operator<(const RequestSentence &, const RequestSentence &); +}; typedef std::vector RequestSentences; struct PCItem { Ptr segments; Ptr sentences; - PCItem(): segments(NULL), sentences(NULL) {} - PCItem(Ptr segments, Ptr sentences): - segments(segments), sentences(sentences){} + PCItem() : segments(NULL), sentences(NULL) {} + PCItem(Ptr segments, Ptr sentences) + : segments(segments), sentences(sentences) {} - void operator=(const PCItem &b){ + void operator=(const PCItem &b) { segments = b.segments; sentences = b.sentences; } - bool isPoison(){ - return (segments == NULL); - } - + bool isPoison() { return (segments == NULL); } }; +} // namespace bergamot +} // namespace marian - -} // namespace bergamot -} // namespace marian - -#endif // __BERGAMOT_REQUEST_H +#endif // SRC_BERGAMOT_REQUEST_H_ diff --git a/src/bergamot/service.cpp b/src/bergamot/service.cpp index 7cd282d..3a439b8 100644 --- a/src/bergamot/service.cpp +++ b/src/bergamot/service.cpp @@ -1,15 +1,14 @@ #include "service.h" + #include "utils.h" -#include #include +#include namespace marian { namespace bergamot { -Service::Service(Ptr options) - : text_processor_(options), - batcher_(options), - running_(true), +Service::Service(Ptr options) + : text_processor_(options), batcher_(options), running_(true), requestId_(0) { int num_workers = options->get("cpu-threads"); @@ -19,48 +18,42 @@ Service::Service(Ptr options) // @TODO(jerin): Fix hardcode, 100*num_workers // @TODO(jerin): make_unique or UNew instead - pcqueue_ = UPtr>(new PCQueue(100*num_workers)); + pcqueue_ = UPtr>(new PCQueue(100 * num_workers)); workers_.reserve(num_workers); - for(int i=0; i < num_workers; i++){ + for (int i = 0; i < num_workers; i++) { marian::DeviceId deviceId(i, DeviceType::cpu); - UPtr batch_translator - = UPtr(new BatchTranslator(deviceId, pcqueue_.get(), options)); + + UPtr batch_translator = UPtr( + new BatchTranslator(deviceId, pcqueue_.get(), options)); + workers_.push_back(std::move(batch_translator)); } } std::future Service::queue(const string_view &input) { // @TODO(jerin): Place a queue to keep track of requests here. - Ptr segments = New(); Ptr sourceAlignments = New(); text_processor_.query_to_segments(input, segments, sourceAlignments); - for(auto &segment: *segments){ - PLOG("main", info, "[token-size {}]", ((int)segment.size())); - } - - Ptr> - translationResultPromise = New>(); + Ptr> translationResultPromise = + New>(); auto future = translationResultPromise->get_future(); - Ptr request = New(requestId_++, - vocabs_, - input, - std::move(segments), - std::move(sourceAlignments), - translationResultPromise); + Ptr request = + New(requestId_++, vocabs_, input, std::move(segments), + std::move(sourceAlignments), translationResultPromise); - for (int i = 0; i < request->size(); i++) { + for (int i = 0; i < request->numSegments(); i++) { RequestSentence requestSentence(i, request); batcher_.addSentenceWithPriority(requestSentence); } /* Cleave batch, run translation */ - Ptr batchSegments; + Ptr batchSegments; Ptr batchSentences; int counter = 0; @@ -70,43 +63,39 @@ std::future Service::queue(const string_view &input) { batchSentences = New>(); batcher_.cleave_batch(batchSegments, batchSentences); - if(batchSegments->size() > 0){ + if (batchSegments->size() > 0) { PCItem pcitem(batchSegments, batchSentences); pcqueue_->Produce(pcitem); ++counter; PLOG("main", info, "Batch {} generated", counter); } } while (batchSegments->size() > 0); - - return future; + return future; } -std::future -Service::translate(const string_view &input) { +std::future Service::translate(const string_view &input) { return queue(input); } -void Service::stop(){ - if(running_){ - int counter = 0; - for(auto &worker: workers_){ - PCItem pcitem; - pcqueue_->Produce(pcitem); - PLOG("main", info, "Adding poison {}", counter); - ++counter; - } - for(auto &worker: workers_){ - worker->join(); - } - - running_ = false; +void Service::stop() { + if (running_) { + int counter = 0; + for (auto &worker : workers_) { + PCItem pcitem; + pcqueue_->Produce(pcitem); + PLOG("main", info, "Adding poison {}", counter); + ++counter; + } + for (auto &worker : workers_) { + worker->join(); + } + + running_ = false; } } -Service::~Service(){ - stop(); -} +Service::~Service() { stop(); } -} // namespace bergamot -} // namespace marian +} // namespace bergamot +} // namespace marian diff --git a/src/bergamot/service.h b/src/bergamot/service.h index a63f63e..83c1a97 100644 --- a/src/bergamot/service.h +++ b/src/bergamot/service.h @@ -1,18 +1,19 @@ -#ifndef __SERVICE_H -#define __SERVICE_H +#ifndef SRC_BERGAMOT_SERVICE_H_ +#define SRC_BERGAMOT_SERVICE_H_ -#include "data/types.h" -#include "translation_result.h" - -#include "textops.h" -#include "batcher.h" #include "batch_translator.h" +#include "batcher.h" #include "pcqueue.h" +#include "textops.h" +#include "translation_result.h" + +#include + +#include "data/types.h" namespace marian { namespace bergamot { - class Service { public: explicit Service(Ptr); @@ -31,7 +32,7 @@ class Service { unsigned int requestId_; }; -} // namespace bergamot -} // namespace marian +} // namespace bergamot +} // namespace marian -#endif // __SERVICE_H +#endif // SRC_BERGAMOT_SERVICE_H_ diff --git a/src/bergamot/textops.cpp b/src/bergamot/textops.cpp index a605889..55656e7 100644 --- a/src/bergamot/textops.cpp +++ b/src/bergamot/textops.cpp @@ -1,37 +1,32 @@ -#include "utils.h" #include "textops.h" +#include "utils.h" #include #include -#include #include #include - +#include namespace marian { namespace bergamot { - SentenceSplitter::SentenceSplitter(marian::Ptr options) : options_(options) { std::string smode_str = options_->get("ssplit-mode", ""); mode_ = string2splitmode(smode_str); - std::string ssplit_prefix_file = options_->get - ("ssplit-prefix-file", ""); + std::string ssplit_prefix_file = + options_->get("ssplit-prefix-file", ""); if (ssplit_prefix_file.size()) { - ssplit_prefix_file - = marian::cli::interpolateEnvVars(ssplit_prefix_file); + ssplit_prefix_file = marian::cli::interpolateEnvVars(ssplit_prefix_file); - LOG(info, - "Loading protected prefixes for sentence splitting from {}", + LOG(info, "Loading protected prefixes for sentence splitting from {}", ssplit_prefix_file); ssplit_.load(ssplit_prefix_file); } else { - LOG(warn, - "Missing list of protected prefixes for sentence splitting. " - "Set with --ssplit-prefix-file."); + LOG(warn, "Missing list of protected prefixes for sentence splitting. " + "Set with --ssplit-prefix-file."); } } @@ -55,25 +50,22 @@ SentenceSplitter::string2splitmode(const std::string &m) { return splitmode::wrapped_text; } -Tokenizer::Tokenizer(Ptr options): - inference_(true), addEOS_(false) { +Tokenizer::Tokenizer(Ptr options) : inference_(true), addEOS_(false) { vocabs_ = loadVocabularies(options); } - -Segment Tokenizer::tokenize(string_view const &snt, +Segment Tokenizer::tokenize(string_view const &snt, SourceAlignment &sourceAlignment) { // TODO(jerin): Bunch of hardcode here, 1, 0, need to get rid off somehow. - return vocabs_[0]->encodePreservingSource(snt, - sourceAlignment, - addEOS_, + return vocabs_[0]->encodePreservingSource(snt, sourceAlignment, addEOS_, inference_); } TextProcessor::TextProcessor(Ptr options) : tokenizer_(options), sentence_splitter_(options) { max_input_sentence_tokens_ = options->get("max-input-sentence-tokens"); - max_input_sentence_tokens_ = max_input_sentence_tokens_ - 1; // Account for EOS + max_input_sentence_tokens_ = + max_input_sentence_tokens_ - 1; // Account for EOS // Dirty assert, should do at configparse assert(max_input_sentence_tokens_ > 0); } @@ -88,8 +80,8 @@ void TextProcessor::query_to_segments(const string_view &query, LOG(trace, "SNT: {}", snt); string_view snt_string_view(snt.data(), snt.size()); SourceAlignment snt_alignment; - Segment tokenized_sentence - = tokenizer_.tokenize(snt_string_view, snt_alignment); + Segment tokenized_sentence = + tokenizer_.tokenize(snt_string_view, snt_alignment); if (tokenized_sentence.size() > max_input_sentence_tokens_) { int offset; @@ -102,7 +94,7 @@ void TextProcessor::query_to_segments(const string_view &query, segments->push_back(segment); auto astart = snt_alignment.begin() + offset; - SourceAlignment segment_alignment(astart, astart+offset); + SourceAlignment segment_alignment(astart, astart + offset); sourceAlignments->push_back(segment_alignment); } @@ -125,5 +117,5 @@ void TextProcessor::query_to_segments(const string_view &query, } } -} // namespace bergamot -} // namespace marian +} // namespace bergamot +} // namespace marian diff --git a/src/bergamot/textops.h b/src/bergamot/textops.h index 2568612..69e35b6 100644 --- a/src/bergamot/textops.h +++ b/src/bergamot/textops.h @@ -2,14 +2,14 @@ #define __BERGAMOT_TEXTOPS_H #include "common/definitions.h" -#include "common/options.h" -#include "ssplit/ssplit.h" #include "common/logging.h" -#include "common/types.h" // missing in shortlist.h +#include "common/options.h" +#include "common/types.h" // missing in shortlist.h #include "common/utils.h" -#include "data/shortlist.h" #include "data/sentencepiece_vocab.h" +#include "data/shortlist.h" #include "definitions.h" +#include "ssplit/ssplit.h" #include #include @@ -17,7 +17,6 @@ namespace marian { namespace bergamot { - class SentenceSplitter { public: explicit SentenceSplitter(Ptr options); @@ -31,7 +30,6 @@ class SentenceSplitter { }; class Tokenizer { - public: std::vector> vocabs_; bool inference_; @@ -41,17 +39,16 @@ class Tokenizer { }; class TextProcessor { - public: +public: Tokenizer tokenizer_; unsigned int max_input_sentence_tokens_; SentenceSplitter sentence_splitter_; explicit TextProcessor(Ptr); - void query_to_segments(const string_view &query, - Ptr, + void query_to_segments(const string_view &query, Ptr, Ptr); }; -} // namespace bergamot -} // namespace marian +} // namespace bergamot +} // namespace marian #endif // __BERGAMOT_TEXTOPS_H diff --git a/src/bergamot/translation_result.h b/src/bergamot/translation_result.h index e49a7a3..426e03c 100644 --- a/src/bergamot/translation_result.h +++ b/src/bergamot/translation_result.h @@ -1,4 +1,8 @@ -#pragma once +#ifndef SRC_BERGAMOT_TRANSLATION_RESULT_H_ +#define SRC_BERGAMOT_TRANSLATION_RESULT_H_ + +#include +#include namespace marian { namespace bergamot { @@ -11,3 +15,5 @@ struct TranslationResult { }; } // namespace bergamot } // namespace marian + +#endif // SRC_BERGAMOT_TRANSLATION_RESULT_H_