Merge pull request #212 from torbjoernk/feature/fix-mpi
partway through fixing hangs
memmett committed May 15, 2015
2 parents 9e3893e + b32ccc3 commit 44e6d86
Showing 8 changed files with 185 additions and 89 deletions.
1 change: 1 addition & 0 deletions include/pfasst/controller/pfasst.hpp
@@ -80,6 +80,7 @@ namespace pfasst
* @param[in] level_iter level iterator providing information to compute the communication tag
*/
virtual int tag(LevelIter level_iter);
virtual int stag(LevelIter level_iter);

/**
* Post current status and values to next processor.
2 changes: 1 addition & 1 deletion include/pfasst/encap/vector.hpp
@@ -61,7 +61,7 @@ namespace pfasst
*/
VectorEncapsulation(Encapsulation<time>&& other);

virtual ~VectorEncapsulation() = default;
virtual ~VectorEncapsulation();
//! @}

//! @{
14 changes: 9 additions & 5 deletions include/pfasst/interfaces.hpp
@@ -86,6 +86,10 @@ namespace pfasst
*/
class IStatus
{
public:
static const int NOT_CONVERGED = 0;
static const int CONVERGED = 1;

protected:
ICommunicator* comm;

@@ -139,9 +143,9 @@
//! @}

//! @{
virtual void post() = 0;
virtual void send() = 0;
virtual void recv() = 0;
virtual void post(int tag) = 0;
virtual void send(int tag) = 0;
virtual void recv(int tag) = 0;
//! @}
};

@@ -218,7 +222,7 @@
* Perform one SDC sweep/iteration.
*
* Compute a correction and update solution values.
* Note that this function can assume that valid function values exist from a previous
* Note that this function can assume that valid function values exist from a previous
* pfasst::ISweeper::sweep() or pfasst::ISweeper::predict().
*/
virtual void sweep() = 0;
@@ -310,7 +314,7 @@
*
* @param[in,out] dst sweeper to interpolate onto (i.e. fine level)
* @param[in] src sweeper to interpolate from (i.e. coarse level)
* @param[in] interp_initial `true` if a delta for the initial condition should also be
* @param[in] interp_initial `true` if a delta for the initial condition should also be
* computed (PFASST)
*/
virtual void interpolate(shared_ptr<ISweeper<time>> dst,
45 changes: 41 additions & 4 deletions include/pfasst/mpi_communicator.hpp
@@ -11,6 +11,29 @@ using namespace std;
#include "pfasst/logging.hpp"


/**
* Creates and initializes a new _empty_ `MPI_Status` object.
*
* An _empty_ `MPI_Status` is an `MPI_Status` object whose `MPI_ERROR` is set to `MPI_SUCCESS`,
* `MPI_SOURCE` to `MPI_ANY_SOURCE`, and `MPI_TAG` to `MPI_ANY_TAG`
* (cf. MPI Standard v3, section 3.7.3).
*
* Rationale: some MPI implementations do not initialize the members of `MPI_Status` correctly.
*
* @returns _empty_ `MPI_Status` object
*
* @ingroup Utilities
*/
inline static MPI_Status MPI_Status_factory()
{
MPI_Status stat;
stat.MPI_ERROR = MPI_SUCCESS;
stat.MPI_SOURCE = MPI_ANY_SOURCE;
stat.MPI_TAG = MPI_ANY_TAG;
return stat;
}


namespace pfasst
{
namespace mpi
@@ -19,10 +42,24 @@ namespace pfasst
: public runtime_error
{
public:
explicit MPIError(const string& msg="");
MPIError(const string& msg="");
virtual const char* what() const throw();
static MPIError from_code(const int err_code);
};

/**
* Checks an MPI error code.
*
* In case @p err_code is not `MPI_SUCCESS`, this throws an MPIError whose message is the
* descriptive string the MPI implementation associates with that error code.
*/
inline static void check_mpi_error(const int err_code)
{
if (err_code != MPI_SUCCESS) {
throw MPIError::from_code(err_code);
}
}
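
`MPIError::from_code` is only declared here; its definition is not part of this diff. A minimal sketch of what such a factory could look like, assuming it simply wraps the standard `MPI_Error_string` call to obtain the descriptive message:

// Hypothetical sketch, not the code from this commit: build an MPIError whose message
// is the descriptive string MPI associates with the given error code.
MPIError MPIError::from_code(const int err_code)
{
  char buffer[MPI_MAX_ERROR_STRING];
  int len = 0;
  MPI_Error_string(err_code, buffer, &len);
  return MPIError(string(buffer, len));
}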


// forward declare for MPICommunicator
class MPIStatus;
@@ -68,9 +105,9 @@
virtual void clear() override;
virtual void set_converged(bool converged) override;
virtual bool get_converged(int rank) override;
virtual void post();
virtual void send();
virtual void recv();
virtual void post(int tag) override;
virtual void send(int tag) override;
virtual void recv(int tag) override;
};
} // ::pfasst::mpi
} // ::pfasst
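
The implementations of the tag-aware `MPIStatus::post(int)`, `send(int)` and `recv(int)` declared above live in the implementation header, which is not part of this diff. The following free-standing sketch only illustrates the assumed wire format — a single `int` per status message, matching the `CONVERGED`/`NOT_CONVERGED` constants added to `IStatus`; it is not the code from this commit:

// Illustration only (assumed message format, not the commit's MPIStatus code):
// receive a predecessor's convergence flag as one MPI_INT and map it to a bool.
static bool recv_converged_flag(MPI_Comm comm, int source_rank, int tag)
{
  MPI_Status stat = MPI_Status_factory();
  int flag = pfasst::IStatus::NOT_CONVERGED;
  int err = MPI_Recv(&flag, 1, MPI_INT, source_rank, tag, comm, &stat);
  pfasst::mpi::check_mpi_error(err);
  return flag == pfasst::IStatus::CONVERGED;
}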
32 changes: 18 additions & 14 deletions src/pfasst/controller/pfasst_impl.hpp
@@ -101,9 +101,7 @@ namespace pfasst
}

fine->send(comm, tag(l), false);

trns->restrict(crse, fine, true);

trns->fas(this->get_time_step(), crse, fine);
crse->save();

@@ -118,11 +116,8 @@
auto trns = level_iter.transfer();

trns->interpolate(fine, crse, true);

if (this->comm->status->previous_is_iterating()) {
fine->recv(comm, tag(level_iter), false);
trns->interpolate_initial(fine, crse);
}
fine->recv(comm, tag(level_iter), false);
trns->interpolate_initial(fine, crse);

if (level_iter < this->finest()) {
perform_sweeps(level_iter.level);
@@ -139,10 +134,11 @@
if (this->comm->status->previous_is_iterating()) {
crse->recv(comm, tag(level_iter), true);
}
this->comm->status->recv();
this->comm->status->recv(stag(level_iter));
this->perform_sweeps(level_iter.level);
crse->send(comm, tag(level_iter), true);
this->comm->status->send();
this->comm->status->set_converged(!this->comm->status->keep_iterating());
this->comm->status->send(stag(level_iter));
return level_iter + 1;
}

@@ -208,13 +204,19 @@
* @internals
* A simple formula is used with current level index \\( L \\) (provided by @p level_iter) and
* current iteration number \\( I \\):
* \\[ L * 10000 + I + 10 \\]
* \\[ (L+1) * 10000 + I \\]
* @endinternals
*/
template<typename time>
int PFASST<time>::tag(LevelIter level_iter)
{
return level_iter.level * 10000 + this->get_iteration() + 10;
return (level_iter.level+1) * 10000 + this->get_iteration();
}

template<typename time>
int PFASST<time>::stag(LevelIter level_iter)
{
return level_iter.level * 1000 + this->get_iteration();
}
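
As a quick illustration of the two tag spaces (the numbers are the formulas evaluated by hand, not program output): for level index 1 and iteration 3, `tag()` returns (1+1) * 10000 + 3 = 20003, while `stag()` returns 1 * 1000 + 3 = 1003. The data and status tags stay in disjoint ranges as long as level * 1000 + iteration remains below 10000.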

/**
@@ -224,9 +226,11 @@
template<typename time>
void PFASST<time>::post()
{
this->comm->status->post();
for (auto l = this->coarsest() + 1; l <= this->finest(); ++l) {
l.current()->post(comm, tag(l));
if (this->comm->status->previous_is_iterating()) {
this->comm->status->post(0);
for (auto l = this->coarsest() + 1; l <= this->finest(); ++l) {
l.current()->post(comm, tag(l));
}
}
}
} // ::pfasst
90 changes: 59 additions & 31 deletions src/pfasst/encap/vector_impl.hpp
@@ -34,6 +34,22 @@ namespace pfasst
: VectorEncapsulation(dynamic_cast<VectorEncapsulation<scalar, time>&&>(other))
{}

template<typename scalar, typename time>
VectorEncapsulation<scalar, time>::~VectorEncapsulation()
{
#ifdef WITH_MPI
if (this->send_request != MPI_REQUEST_NULL) {
MPI_Status stat = MPI_Status_factory();
CLOG(DEBUG, "Encap") << "waiting for open send request";
int err = MPI_Wait(&(this->send_request), &stat);
check_mpi_error(err);
CLOG(DEBUG, "Encap") << "waited for open send request";
}
assert(this->recv_request == MPI_REQUEST_NULL);
assert(this->send_request == MPI_REQUEST_NULL);
#endif
}

template<typename scalar, typename time>
void VectorEncapsulation<scalar, time>::zero()
{
@@ -165,18 +181,23 @@
}

#ifdef WITH_MPI
template<typename scalar, typename time>
template<typename scalar, typename time>
void VectorEncapsulation<scalar, time>::post(ICommunicator* comm, int tag)
{
auto& mpi = as_mpi(comm);
if (mpi.size() == 1) { return; }
if (mpi.rank() == 0) { return; }

int err = MPI_Irecv(this->data(), sizeof(scalar) * this->size(), MPI_CHAR,
(mpi.rank() - 1) % mpi.size(), tag, mpi.comm, &recv_request);
if (err != MPI_SUCCESS) {
throw MPIError();
if (this->recv_request != MPI_REQUEST_NULL) {
throw MPIError("a previous receive request is still open");
}

int src = (mpi.rank() - 1) % mpi.size();
CLOG(DEBUG, "Encap") << "non-blocking receiving from rank " << src << " with tag=" << tag;
int err = MPI_Irecv(this->data(), sizeof(scalar) * this->size(), MPI_CHAR,
src, tag, mpi.comm, &this->recv_request);
check_mpi_error(err);
CLOG(DEBUG, "Encap") << "non-blocking received from rank " << src << " with tag=" << tag;
}

template<typename scalar, typename time>
@@ -186,18 +207,23 @@
if (mpi.size() == 1) { return; }
if (mpi.rank() == 0) { return; }

int err;
MPI_Status stat = MPI_Status_factory();
int err = MPI_SUCCESS;

if (blocking) {
MPI_Status stat;
int src = (mpi.rank() - 1) % mpi.size();
CLOG(DEBUG, "Encap") << "blocking receive from rank " << src << " with tag=" << tag;
err = MPI_Recv(this->data(), sizeof(scalar) * this->size(), MPI_CHAR,
(mpi.rank() - 1) % mpi.size(), tag, mpi.comm, &stat);
src, tag, mpi.comm, &stat);
check_mpi_error(err);
CLOG(DEBUG, "Encap") << "received blocking from rank " << src << " with tag=" << tag << ": " << stat;
} else {
MPI_Status stat;
err = MPI_Wait(&recv_request, &stat);
}

if (err != MPI_SUCCESS) {
throw MPIError();
if (this->recv_request != MPI_REQUEST_NULL) {
CLOG(DEBUG, "Encap") << "waiting on last receive request";
err = MPI_Wait(&(this->recv_request), &stat);
check_mpi_error(err);
CLOG(DEBUG, "Encap") << "waited on last receive request: " << stat;
}
}
}

@@ -208,36 +234,38 @@
if (mpi.size() == 1) { return; }
if (mpi.rank() == mpi.size() - 1) { return; }

MPI_Status stat = MPI_Status_factory();
int err = MPI_SUCCESS;
int dest = (mpi.rank() + 1) % mpi.size();

if (blocking) {
err = MPI_Send(this->data(), sizeof(scalar) * this->size(), MPI_CHAR,
(mpi.rank() + 1) % mpi.size(), tag, mpi.comm);
CLOG(DEBUG, "Encap") << "blocking send to rank " << dest << " with tag=" << tag;
err = MPI_Send(this->data(), sizeof(scalar) * this->size(), MPI_CHAR, dest, tag, mpi.comm);
check_mpi_error(err);
CLOG(DEBUG, "Encap") << "sent blocking to rank " << dest << " with tag=" << tag;
} else {
MPI_Status stat;
err = MPI_Wait(&send_request, &stat);
if (err != MPI_SUCCESS) {
throw MPIError();
}

// got never in here
CLOG(DEBUG, "Encap") << "waiting on last send request to finish";
err = MPI_Wait(&(this->send_request), &stat);
check_mpi_error(err);
CLOG(DEBUG, "Encap") << "waited on last send request: " << stat;
CLOG(DEBUG, "Encap") << "non-blocking sending to rank " << dest << " with tag=" << tag;
err = MPI_Isend(this->data(), sizeof(scalar) * this->size(), MPI_CHAR,
(mpi.rank() + 1) % mpi.size(), tag, mpi.comm, &send_request);
}

if (err != MPI_SUCCESS) {
throw MPIError();
dest, tag, mpi.comm, &(this->send_request));
check_mpi_error(err);
CLOG(DEBUG, "Encap") << "sent non-blocking to rank " << dest << " with tag=" << tag;
}
}

template<typename scalar, typename time>
void VectorEncapsulation<scalar, time>::broadcast(ICommunicator* comm)
{
auto& mpi = as_mpi(comm);
CLOG(DEBUG, "Encap") << "broadcasting";
int err = MPI_Bcast(this->data(), sizeof(scalar) * this->size(), MPI_CHAR,
comm->size()-1, mpi.comm);

if (err != MPI_SUCCESS) {
throw MPIError();
}
check_mpi_error(err);
CLOG(DEBUG, "Encap") << "broadcasted";
}
#endif

7 changes: 6 additions & 1 deletion src/pfasst/interfaces_impl.hpp
@@ -6,6 +6,7 @@
using namespace std;

#include "pfasst/globals.hpp"
#include "pfasst/logging.hpp"


namespace pfasst
@@ -78,7 +79,11 @@
if (this->comm->rank() == 0) {
return !this->get_converged(0);
}
return !this->get_converged(this->comm->rank()) || !this->get_converged(this->comm->rank() - 1);
bool keep_iterating = !this->get_converged(this->comm->rank() - 1) || !this->get_converged(this->comm->rank());
CLOG(DEBUG, "Controller") << "previous converged: " << boolalpha << this->get_converged(this->comm->rank() - 1)
<< "; this converged: " << boolalpha << this->get_converged(this->comm->rank())
<< " --> keep iterating: " << boolalpha << keep_iterating;
return keep_iterating;
}
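
For example, under this check rank 2 keeps iterating as long as either rank 1 or rank 2 is still reported as not converged; only rank 0 decides based on its own flag alone.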


