Skip to content

Commit

Permalink
Merge pull request #864 from RuleOfThrees/multiasync
Browse files Browse the repository at this point in the history
MultiAsync
  • Loading branch information
COM8 authored Jan 26, 2023
2 parents 47438c7 + 3e0570e commit 0817715
Show file tree
Hide file tree
Showing 16 changed files with 774 additions and 30 deletions.
5 changes: 5 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ _site/

# Vim
.ycm_extra_conf.py*
*.swp

# VSCode
.vscode/
Expand All @@ -52,5 +53,9 @@ _site/
# clangd
.cache/

# compilation database
# used in various editor configurations, such as vim & YcM
compile_commands.json

# macOS
.DS_Store
1 change: 1 addition & 0 deletions cpr/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ add_library(cpr
async.cpp
auth.cpp
bearer.cpp
callback.cpp
cert_info.cpp
cookies.cpp
cprtypes.cpp
Expand Down
14 changes: 14 additions & 0 deletions cpr/callback.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
#include <cpr/callback.h>
#include <curl/curl.h>
#include <functional>

namespace cpr {

void CancellationCallback::SetProgressCallback(ProgressCallback& u_cb) {
user_cb.emplace(std::reference_wrapper{u_cb});
}
bool CancellationCallback::operator()(cpr_pf_arg_t dltotal, cpr_pf_arg_t dlnow, cpr_pf_arg_t ultotal, cpr_pf_arg_t ulnow) const {
const bool cont_operation{!cancellation_state->load()};
return user_cb ? (cont_operation && (*user_cb)(dltotal, dlnow, ultotal, ulnow)) : cont_operation;
}
} // namespace cpr
21 changes: 19 additions & 2 deletions cpr/session.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -246,11 +246,15 @@ void Session::SetWriteCallback(const WriteCallback& write) {

void Session::SetProgressCallback(const ProgressCallback& progress) {
progresscb_ = progress;
if (isCancellable) {
cancellationcb_.SetProgressCallback(progresscb_);
return;
}
#if LIBCURL_VERSION_NUM < 0x072000
curl_easy_setopt(curl_->handle, CURLOPT_PROGRESSFUNCTION, cpr::util::progressUserFunction);
curl_easy_setopt(curl_->handle, CURLOPT_PROGRESSFUNCTION, cpr::util::progressUserFunction<ProgressCallback>);
curl_easy_setopt(curl_->handle, CURLOPT_PROGRESSDATA, &progresscb_);
#else
curl_easy_setopt(curl_->handle, CURLOPT_XFERINFOFUNCTION, cpr::util::progressUserFunction);
curl_easy_setopt(curl_->handle, CURLOPT_XFERINFOFUNCTION, cpr::util::progressUserFunction<ProgressCallback>);
curl_easy_setopt(curl_->handle, CURLOPT_XFERINFODATA, &progresscb_);
#endif
curl_easy_setopt(curl_->handle, CURLOPT_NOPROGRESS, 0L);
Expand Down Expand Up @@ -968,4 +972,17 @@ void Session::SetOption(const ReserveSize& reserve_size) { SetReserveSize(reserv
void Session::SetOption(const AcceptEncoding& accept_encoding) { SetAcceptEncoding(accept_encoding); }
void Session::SetOption(AcceptEncoding&& accept_encoding) { SetAcceptEncoding(accept_encoding); }
// clang-format on

void Session::SetCancellationParam(std::shared_ptr<std::atomic_bool> param) {
cancellationcb_ = CancellationCallback{std::move(param)};
isCancellable = true;
#if LIBCURL_VERSION_NUM < 0x072000
curl_easy_setopt(curl_->handle, CURLOPT_PROGRESSFUNCTION, cpr::util::progressUserFunction<CancellationCallback>);
curl_easy_setopt(curl_->handle, CURLOPT_PROGRESSDATA, &cancellationcb_);
#else
curl_easy_setopt(curl_->handle, CURLOPT_XFERINFOFUNCTION, cpr::util::progressUserFunction<CancellationCallback>);
curl_easy_setopt(curl_->handle, CURLOPT_XFERINFODATA, &cancellationcb_);
#endif
curl_easy_setopt(curl_->handle, CURLOPT_NOPROGRESS, 0L);
}
} // namespace cpr
8 changes: 0 additions & 8 deletions cpr/util.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -144,14 +144,6 @@ size_t writeUserFunction(char* ptr, size_t size, size_t nmemb, const WriteCallba
return (*write)({ptr, size}) ? size : 0;
}

#if LIBCURL_VERSION_NUM < 0x072000
int progressUserFunction(const ProgressCallback* progress, double dltotal, double dlnow, double ultotal, double ulnow) {
#else
int progressUserFunction(const ProgressCallback* progress, curl_off_t dltotal, curl_off_t dlnow, curl_off_t ultotal, curl_off_t ulnow) {
#endif
return (*progress)(dltotal, dlnow, ultotal, ulnow) ? 0 : 1;
} // namespace cpr::util

int debugUserFunction(CURL* /*handle*/, curl_infotype type, char* data, size_t size, const DebugCallback* debug) {
(*debug)(static_cast<DebugCallback::InfoType>(type), std::string(data, size));
return 0;
Expand Down
1 change: 1 addition & 0 deletions include/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ target_sources(cpr PRIVATE
cpr/accept_encoding.h
cpr/api.h
cpr/async.h
cpr/async_wrapper.h
cpr/auth.h
cpr/bearer.h
cpr/body.h
Expand Down
83 changes: 79 additions & 4 deletions include/cpr/api.h
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
#include <utility>

#include "cpr/async.h"
#include "cpr/async_wrapper.h"
#include "cpr/auth.h"
#include "cpr/bearer.h"
#include "cpr/cprtypes.h"
Expand All @@ -17,11 +18,10 @@
#include "cpr/response.h"
#include "cpr/session.h"
#include <cpr/filesystem.h>
#include <utility>

namespace cpr {

using AsyncResponse = std::future<Response>;
using AsyncResponse = AsyncWrapper<Response>;

namespace priv {

Expand Down Expand Up @@ -85,6 +85,32 @@ void setup_multiperform(MultiPerform& multiperform, Ts&&... ts) {
setup_multiperform_internal<Ts...>(multiperform, std::forward<Ts>(ts)...);
}

using session_action_t = cpr::Response (cpr::Session::*)();

template <session_action_t SessionAction, typename T>
void setup_multiasync(std::vector<AsyncWrapper<Response, true>>& responses, T&& parameters) {
std::shared_ptr<std::atomic_bool> cancellation_state = std::make_shared<std::atomic_bool>(false);

std::function<Response(T)> execFn{[cancellation_state](T params) {
if (cancellation_state->load()) {
return Response{};
}
cpr::Session s{};
s.SetCancellationParam(cancellation_state);
apply_set_option(s, std::forward<T>(params));
return std::invoke(SessionAction, s);
}};
responses.emplace_back(GlobalThreadPool::GetInstance()->Submit(std::move(execFn), std::forward<T>(parameters)), std::move(cancellation_state));
}

template <session_action_t SessionAction, typename T, typename... Ts>
void setup_multiasync(std::vector<AsyncWrapper<Response, true>>& responses, T&& head, Ts&&... tail) {
setup_multiasync<SessionAction>(responses, std::forward<T>(head));
if constexpr (sizeof...(Ts) > 0) {
setup_multiasync<SessionAction>(responses, std::forward<Ts>(tail)...);
}
}

} // namespace priv

// Get methods
Expand Down Expand Up @@ -245,7 +271,7 @@ Response Download(std::ofstream& file, Ts&&... ts) {
// Download async method
template <typename... Ts>
AsyncResponse DownloadAsync(fs::path local_path, Ts... ts) {
return std::async(
return AsyncWrapper{std::async(
std::launch::async,
[](fs::path local_path_, Ts... ts_) {
#ifdef CPR_USE_BOOST_FILESYSTEM
Expand All @@ -255,7 +281,7 @@ AsyncResponse DownloadAsync(fs::path local_path, Ts... ts) {
#endif
return Download(f, std::move(ts_)...);
},
std::move(local_path), std::move(ts)...);
std::move(local_path), std::move(ts)...)};
}

// Download with user callback
Expand Down Expand Up @@ -316,6 +342,55 @@ std::vector<Response> MultiPost(Ts&&... ts) {
return multiperform.Post();
}

template <typename... Ts>
std::vector<AsyncWrapper<Response, true>> MultiGetAsync(Ts&&... ts) {
std::vector<AsyncWrapper<Response, true>> ret{};
priv::setup_multiasync<&cpr::Session::Get>(ret, std::forward<Ts>(ts)...);
return ret;
}

template <typename... Ts>
std::vector<AsyncWrapper<Response, true>> MultiDeleteAsync(Ts&&... ts) {
std::vector<AsyncWrapper<Response, true>> ret{};
priv::setup_multiasync<&cpr::Session::Delete>(ret, std::forward<Ts>(ts)...);
return ret;
}

template <typename... Ts>
std::vector<AsyncWrapper<Response, true>> MultiHeadAsync(Ts&&... ts) {
std::vector<AsyncWrapper<Response, true>> ret{};
priv::setup_multiasync<&cpr::Session::Head>(ret, std::forward<Ts>(ts)...);
return ret;
}
template <typename... Ts>
std::vector<AsyncWrapper<Response, true>> MultiOptionsAsync(Ts&&... ts) {
std::vector<AsyncWrapper<Response, true>> ret{};
priv::setup_multiasync<&cpr::Session::Options>(ret, std::forward<Ts>(ts)...);
return ret;
}

template <typename... Ts>
std::vector<AsyncWrapper<Response, true>> MultiPatchAsync(Ts&&... ts) {
std::vector<AsyncWrapper<Response, true>> ret{};
priv::setup_multiasync<&cpr::Session::Patch>(ret, std::forward<Ts>(ts)...);
return ret;
}

template <typename... Ts>
std::vector<AsyncWrapper<Response, true>> MultiPostAsync(Ts&&... ts) {
std::vector<AsyncWrapper<Response, true>> ret{};
priv::setup_multiasync<&cpr::Session::Post>(ret, std::forward<Ts>(ts)...);
return ret;
}

template <typename... Ts>
std::vector<AsyncWrapper<Response, true>> MultiPutAsync(Ts&&... ts) {
std::vector<AsyncWrapper<Response, true>> ret{};
priv::setup_multiasync<&cpr::Session::Put>(ret, std::forward<Ts>(ts)...);
return ret;
}


} // namespace cpr

#endif
5 changes: 3 additions & 2 deletions include/cpr/async.h
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
#ifndef CPR_ASYNC_H
#define CPR_ASYNC_H

#include "async_wrapper.h"
#include "singleton.h"
#include "threadpool.h"

Expand All @@ -16,14 +17,14 @@ class GlobalThreadPool : public ThreadPool {
};

/**
* Return a future, calling future.get() will wait task done and return RetType.
* Return a wrapper for a future, calling future.get() will wait until the task is done and return RetType.
* async(fn, args...)
* async(std::bind(&Class::mem_fn, &obj))
* async(std::mem_fn(&Class::mem_fn, &obj))
**/
template <class Fn, class... Args>
auto async(Fn&& fn, Args&&... args) {
return GlobalThreadPool::GetInstance()->Submit(std::forward<Fn>(fn), std::forward<Args>(args)...);
return AsyncWrapper{GlobalThreadPool::GetInstance()->Submit(std::forward<Fn>(fn), std::forward<Args>(args)...)};
}

class async {
Expand Down
140 changes: 140 additions & 0 deletions include/cpr/async_wrapper.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,140 @@
#ifndef CPR_ASYNC_WRAPPER_H
#define CPR_ASYNC_WRAPPER_H

#include <atomic>
#include <future>
#include <memory>

#include "cpr/response.h"

namespace cpr {
enum class [[nodiscard]] CancellationResult { failure, success, invalid_operation };

/**
* A class template intended to wrap results of async operations (instances of std::future<T>)
* and also provide extended capablilities relaed to these requests, for example cancellation.
*
* The RAII semantics are the same as std::future<T> - moveable, not copyable.
*/
template <typename T, bool isCancellable = false>
class AsyncWrapper {
private:
std::future<T> future;
std::shared_ptr<std::atomic_bool> is_cancelled;

public:
// Constructors
explicit AsyncWrapper(std::future<T>&& f) : future{std::move(f)} {}
AsyncWrapper(std::future<T>&& f, std::shared_ptr<std::atomic_bool>&& cancelledState) : future{std::move(f)}, is_cancelled{std::move(cancelledState)} {}

// Copy Semantics
AsyncWrapper(const AsyncWrapper&) = delete;
AsyncWrapper& operator=(const AsyncWrapper&) = delete;

// Move Semantics
AsyncWrapper(AsyncWrapper&&) noexcept = default;
AsyncWrapper& operator=(AsyncWrapper&&) noexcept = default;

// Destructor
~AsyncWrapper() {
if constexpr (isCancellable) {
if(is_cancelled) {
is_cancelled->store(true);
}
}
}
// These methods replicate the behaviour of std::future<T>
[[nodiscard]] T get() {
if constexpr (isCancellable) {
if (IsCancelled()) {
throw std::logic_error{"Calling AsyncWrapper::get on a cancelled request!"};
}
}
if (!future.valid()) {
throw std::logic_error{"Calling AsyncWrapper::get when the associated future instance is invalid!"};
}
return future.get();
}

[[nodiscard]] bool valid() const noexcept {
if constexpr (isCancellable) {
return !is_cancelled->load() && future.valid();
} else {
return future.valid();
}
}

void wait() const {
if constexpr (isCancellable) {
if (is_cancelled->load()) {
throw std::logic_error{"Calling AsyncWrapper::wait when the associated future is invalid or cancelled!"};
}
}
if (!future.valid()) {
throw std::logic_error{"Calling AsyncWrapper::wait_until when the associated future is invalid!"};
}
future.wait();
}

template <class Rep, class Period>
std::future_status wait_for(const std::chrono::duration<Rep, Period>& timeout_duration) const {
if constexpr (isCancellable) {
if (IsCancelled()) {
throw std::logic_error{"Calling AsyncWrapper::wait_for when the associated future is cancelled!"};
}
}
if (!future.valid()) {
throw std::logic_error{"Calling AsyncWrapper::wait_until when the associated future is invalid!"};
}
return future.wait_for(timeout_duration);
}

template <class Clock, class Duration>
std::future_status wait_until(const std::chrono::time_point<Clock, Duration>& timeout_time) const {
if constexpr (isCancellable) {
if (IsCancelled()) {
throw std::logic_error{"Calling AsyncWrapper::wait_until when the associated future is cancelled!"};
}
}
if (!future.valid()) {
throw std::logic_error{"Calling AsyncWrapper::wait_until when the associated future is invalid!"};
}
return future.wait_until(timeout_time);
}

std::shared_future<T> share() noexcept {
return future.share();
}

// Cancellation-related methods
CancellationResult Cancel() {
if constexpr (!isCancellable) {
return CancellationResult::invalid_operation;
}
if (!future.valid() || is_cancelled->load()) {
return CancellationResult::invalid_operation;
}
is_cancelled->store(true);
return CancellationResult::success;
}

[[nodiscard]] bool IsCancelled() const {
if constexpr (isCancellable) {
return is_cancelled->load();
} else {
return false;
}
}
};

// Deduction guides
template <typename T>
AsyncWrapper(std::future<T>&&) -> AsyncWrapper<T, false>;

template <typename T>
AsyncWrapper(std::future<T>&&, std::shared_ptr<std::atomic_bool>&&) -> AsyncWrapper<T, true>;

} // namespace cpr


#endif
Loading

0 comments on commit 0817715

Please sign in to comment.