Skip to content

Commit

Permalink
Style and consistency changes
Browse files Browse the repository at this point in the history
Applying formatting from clang-format.

Inconsistency in Request and RequestSentence due to private variables
resolved by means of changing to class. (#11)
Comments added to request.h, since it holds multiple classes. (#12)
Also takes care of consistent indentation and spacing (#10)
  • Loading branch information
[email protected] committed Jan 11, 2021
1 parent 610c3a3 commit 9fbe3f3
Show file tree
Hide file tree
Showing 13 changed files with 385 additions and 390 deletions.
58 changes: 25 additions & 33 deletions src/bergamot/batch_translator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,18 +6,14 @@ namespace marian {
namespace bergamot {

BatchTranslator::BatchTranslator(DeviceId const device,
PCQueue<PCItem> *pcqueue,
Ptr<Options> options)
: device_(device), options_(options){


PCQueue<PCItem> *pcqueue, Ptr<Options> options)
: device_(device), options_(options) {

ABORT_IF(thread_ != NULL, "Don't call start on a running worker!");
thread_.reset(new std::thread([&]{ this->mainloop(pcqueue); }));

thread_.reset(new std::thread([&] { this->mainloop(pcqueue); }));
}

void BatchTranslator::initGraph(){
void BatchTranslator::initGraph() {
vocabs_ = loadVocabularies(options_);
if (options_->hasAndNotEmpty("shortlist")) {
Ptr<data::ShortlistGenerator const> slgen;
Expand All @@ -27,9 +23,8 @@ void BatchTranslator::initGraph(){
options_, vocabs_.front(), vocabs_.back(), srcIdx, trgIdx, shared_vcb);
}

graph_ = New<ExpressionGraph>(true); // always optimize
auto prec =
options_->get<std::vector<std::string>>("precision", {"float32"});
graph_ = New<ExpressionGraph>(true); // always optimize
auto prec = options_->get<std::vector<std::string>>("precision", {"float32"});
graph_->setDefaultElementType(typeFromString(prec[0]));
graph_->setDevice(device_);
graph_->getBackend()->configureDevice(options_);
Expand All @@ -45,8 +40,8 @@ void BatchTranslator::initGraph(){
graph_->forward();
}

void BatchTranslator::translate(const Ptr<Segments> segments,
Histories &histories){
void BatchTranslator::translate(const Ptr<Segments> segments,
Histories &histories) {
int id = 0;
std::vector<data::SentenceTuple> batchVector;
Timer timer;
Expand All @@ -64,9 +59,11 @@ void BatchTranslator::translate(const Ptr<Segments> segments,
std::vector<size_t> sentenceIds;
std::vector<int> maxDims;
for (auto &ex : batchVector) {
if (maxDims.size() < ex.size()) maxDims.resize(ex.size(), 0);
if (maxDims.size() < ex.size())
maxDims.resize(ex.size(), 0);
for (size_t i = 0; i < ex.size(); ++i) {
if (ex[i].size() > (size_t)maxDims[i]) maxDims[i] = (int)ex[i].size();
if (ex[i].size() > (size_t)maxDims[i])
maxDims[i] = (int)ex[i].size();
}
sentenceIds.push_back(ex.getId());
}
Expand All @@ -79,7 +76,6 @@ void BatchTranslator::translate(const Ptr<Segments> segments,
subBatches.emplace_back(New<SubBatch>(batchSize, maxDims[j], vocabs_[j]));
}


PLOG(_identifier(), info, "subBatches created in {}; ", timer.elapsed());
timer.reset();

Expand All @@ -94,55 +90,51 @@ void BatchTranslator::translate(const Ptr<Segments> segments,
}
}

for (size_t j = 0; j < maxDims.size(); ++j)
for (size_t j = 0; j < maxDims.size(); ++j)
subBatches[j]->setWords(words[j]);

auto batch = Ptr<CorpusBatch>(new CorpusBatch(subBatches));
batch->setSentenceIds(sentenceIds);
PLOG(_identifier(), info, "corpusBatch created in {}; ", timer.elapsed());
timer.reset();



auto trgVocab = vocabs_.back();
auto search = New<BeamSearch>(options_, scorers_, trgVocab);

histories = search->search(graph_, batch);
PLOG(_identifier(), info, "BeamSearch completed in {}; ", timer.elapsed());

timer.reset();

}

void BatchTranslator::mainloop(PCQueue<PCItem> *pcqueue){
void BatchTranslator::mainloop(PCQueue<PCItem> *pcqueue) {
initGraph();
while(running_){
while (running_) {
Timer timer;
PCItem pcitem;
pcqueue->Consume(pcitem);
if(pcitem.isPoison()){
running_ = false;
if (pcitem.isPoison()) {
running_ = false;
} else {
PLOG(_identifier(), info, "consumed item in {}; ", timer.elapsed());
timer.reset();
Histories histories;
translate(pcitem.segments, histories);
PLOG(_identifier(), info, "translated item in {}; ", timer.elapsed());
timer.reset();
for(int i=0; i < (pcitem.sentences)->size(); i++){
for (int i = 0; i < (pcitem.sentences)->size(); i++) {
Ptr<History> history = histories.at(i);
Ptr<Request> request = ((pcitem.sentences)->at(i)).request;
int index = ((pcitem.sentences)->at(i)).index;
request->set_translation(index, history);
RequestSentence requestSentence = pcitem.sentences->at(i);
requestSentence.completeSentence(history);
}
}
}
}

void BatchTranslator::join(){
thread_->join();
thread_.reset();
void BatchTranslator::join() {
thread_->join();
thread_.reset();
}

} // namespace bergamot
} // namespace marian
} // namespace bergamot
} // namespace marian
17 changes: 7 additions & 10 deletions src/bergamot/batch_translator.h
Original file line number Diff line number Diff line change
@@ -1,25 +1,25 @@
#ifndef __BERGAMOT_BATCH_TRANSLATOR_H
#define __BERGAMOT_BATCH_TRANSLATOR_H

#include <atomic>
#include <ctime>
#include <functional>
#include <map>
#include <string>
#include <vector>
#include <atomic>

#include "common/logging.h"
#include "common/utils.h"
#include "data/batch_generator.h"
#include "data/corpus.h"
#include "data/shortlist.h"
#include "data/text_input.h"
#include "translator/history.h"
#include "translator/scorers.h"
#include "definitions.h"
#include "translator/beam_search.h"
#include "pcqueue.h"
#include "request.h"
#include "translator/beam_search.h"
#include "translator/history.h"
#include "translator/scorers.h"

#include "sanelogging.h"

Expand All @@ -29,8 +29,7 @@ namespace bergamot {
class BatchTranslator {
public:
BatchTranslator(const BatchTranslator &) = default;
BatchTranslator(DeviceId const device,
PCQueue<PCItem> *pcqueue,
BatchTranslator(DeviceId const device, PCQueue<PCItem> *pcqueue,
Ptr<Options> options);

void initGraph();
Expand All @@ -39,7 +38,6 @@ class BatchTranslator {
std::string _identifier() { return "worker" + std::to_string(device_.no); }
void join();


private:
Ptr<Options> options_;
DeviceId device_;
Expand All @@ -50,9 +48,8 @@ class BatchTranslator {
bool running_{true};

std::unique_ptr<std::thread> thread_;

};
} // namespace bergamot
} // namespace marian
} // namespace bergamot
} // namespace marian

#endif // __BERGAMOT_BATCH_TRANSLATOR_H
33 changes: 16 additions & 17 deletions src/bergamot/batcher.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,19 +11,18 @@ Batcher::Batcher(Ptr<Options> options) {
max_input_tokens_ = options->get<int>("max-input-tokens");
max_input_sentence_tokens_ = options->get<int>("max-input-sentence-tokens");
bucket.reserve(max_input_sentence_tokens_ + 1);
for(int i=0; i<=max_input_sentence_tokens_; i++){
for (int i = 0; i <= max_input_sentence_tokens_; i++) {
bucket.push_back(std::set<RequestSentence>());
}
}

void Batcher::addSentenceWithPriority(RequestSentence &sentence){
int bucket_id = sentence.num_tokens();
void Batcher::addSentenceWithPriority(RequestSentence &sentence) {
int bucket_id = sentence.numTokens();
assert(bucket_id <= max_input_sentence_tokens_);
bucket[bucket_id].insert(sentence);
}


void Batcher::cleave_batch(Ptr<Segments> segments,
void Batcher::cleave_batch(Ptr<Segments> segments,
Ptr<RequestSentences> sentences) {
/* Temporary stub, needs improvement this section */
int segments_added = 0;
Expand All @@ -32,28 +31,28 @@ void Batcher::cleave_batch(Ptr<Segments> segments,
int prev_padded_batch_size;
for (int i = 0; i < bucket.size(); i++) {
auto p = bucket[i].begin();
while ( p != bucket[i].end() ){
padded_batch_size = (segments_added+1)*i;
if (padded_batch_size < max_input_tokens_){
while (p != bucket[i].end()) {
padded_batch_size = (segments_added + 1) * i;
if (padded_batch_size < max_input_tokens_) {
auto q = p;
current_input_tokens += i;
segments->push_back(q->segment());
Segment segment = q->getUnderlyingSegment();
segments->push_back(std::move(segment));
sentences->push_back(*q);
++p;
++segments_added;
bucket[i].erase(q);
prev_padded_batch_size = padded_batch_size;
}
else{
PLOG("main", info, "New batch generated; {} Segments added;", segments_added);
PLOG("main", info,
"padded_batch_size ({}) current_input_tokens({})",
prev_padded_batch_size, current_input_tokens);
} else {
PLOG("main", info, "New batch generated; {} Segments added;",
segments_added);
PLOG("main", info, "padded_batch_size ({}) current_input_tokens({})",
prev_padded_batch_size, current_input_tokens);
return;
}
}
}
}

} // namespace bergamot
} // namespace marian
} // namespace bergamot
} // namespace marian
10 changes: 4 additions & 6 deletions src/bergamot/batcher.h
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,8 @@
#include "definitions.h"
#include "request.h"

#include <vector>
#include <set>

#include <vector>

namespace marian {
namespace bergamot {
Expand All @@ -19,14 +18,13 @@ class Batcher {
unsigned int max_input_sentence_tokens_;
std::vector<std::set<RequestSentence>> bucket;

public:
public:
explicit Batcher(Ptr<Options> options);
void addSentenceWithPriority(RequestSentence &);
void cleave_batch(Ptr<Segments>, Ptr<RequestSentences>);
};

} // namespace bergamot
} // namespace marian

} // namespace bergamot
} // namespace marian

#endif // __BERGAMOT_BATCHER_H
17 changes: 8 additions & 9 deletions src/bergamot/main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@

#include "service.h"


int main(int argc, char *argv[]) {
marian::ConfigParser cp(marian::cli::mode::translation);

Expand All @@ -31,8 +30,8 @@ int main(int argc, char *argv[]) {
cp.addOption<std::string>(
"--ssplit-prefix-file", "Server Options",
"File with nonbreaking prefixes for sentence splitting.");
cp.addOption<std::string>(
"--ssplit-mode", "Server Options", "[paragraph, sentence, wrapped_text]");
cp.addOption<std::string>("--ssplit-mode", "Server Options",
"[paragraph, sentence, wrapped_text]");
cp.addOption<std::string>("--source-language", "Server Options",
"source language of translation service");
cp.addOption<std::string>("--target-language", "Server Options",
Expand All @@ -57,8 +56,8 @@ int main(int argc, char *argv[]) {
// std::getline(std::cin, input);
// std::cout << input << "\n";

std::ostringstream std_input;
std_input << std::cin.rdbuf();
std::ostringstream std_input;
std_input << std::cin.rdbuf();
std::string input = std_input.str();

marian::string_view input_view(input);
Expand All @@ -68,10 +67,10 @@ int main(int argc, char *argv[]) {
auto translation_result_future = service.translate(input_view);
translation_result_future.wait();
auto translation_result = translation_result_future.get();
for (int i=0; i < translation_result.sources.size(); i++){
std::cout<< "[src] " << translation_result.sources[i]<<"\n";
std::cout<< "[tgt] " << translation_result.translations[i]<<"\n";
std::cout<< "--------------------------------\n";
for (int i = 0; i < translation_result.sources.size(); i++) {
std::cout << "[src] " << translation_result.sources[i] << "\n";
std::cout << "[tgt] " << translation_result.translations[i] << "\n";
std::cout << "--------------------------------\n";
}

service.stop();
Expand Down
Loading

0 comments on commit 9fbe3f3

Please sign in to comment.