Skip to content

Commit

Permalink
Update aiohttp to use client certificate, too (untested)
Browse files Browse the repository at this point in the history
  • Loading branch information
quality-leftovers committed Mar 7, 2024
1 parent 1f4bf8a commit 3843d6a
Show file tree
Hide file tree
Showing 6 changed files with 36 additions and 44 deletions.
2 changes: 1 addition & 1 deletion tests/test_async_uploader.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ def test_upload_chunk(self):
request_length = self.async_uploader.get_request_length()
self.loop.run_until_complete(self.async_uploader.upload_chunk())
self.assertEqual(self.async_uploader.offset, request_length)

def test_upload_chunk_with_creation(self):
with aioresponses() as resps:
resps.post(
Expand Down
29 changes: 0 additions & 29 deletions tests/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,32 +44,3 @@ def test_async_uploader(self):

self.assertIsInstance(async_uploader, AsyncUploader)
self.assertEqual(async_uploader.client, self.client)

class TusClientTestWithClientCertificate(unittest.TestCase):
def setUp(self):
self.client_cert=("/tmp/client.crt.pem")
self.client = client.TusClient('http://tusd.tusdemo.net/files/',
headers={'foo': 'bar'},
client_cert=self.client_cert)

@responses.activate
def test_uploader(self):
url = 'http://tusd.tusdemo.net/files/15acd89eabdf5738ffc'
responses.add(responses.HEAD, url,
adding_headers={"upload-offset": "0"})
uploader = self.client.uploader('./LICENSE', url=url)

self.assertIsInstance(uploader, Uploader)
self.assertEqual(uploader.client, self.client)
self.assertEqual(uploader.client_cert, self.client_cert)

@responses.activate
def test_async_uploader(self):
url = 'http://tusd.tusdemo.net/files/15acd89eabdf5738ffc'
responses.add(responses.HEAD, url,
adding_headers={"upload-offset": "0"})
async_uploader = self.client.async_uploader('./LICENSE', url=url)

self.assertIsInstance(async_uploader, AsyncUploader)
self.assertEqual(async_uploader.client, self.client)
self.assertEqual(async_uploader.client_cert, self.client_cert)
2 changes: 0 additions & 2 deletions tusclient/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,10 +56,8 @@ def uploader(self, *args, **kwargs) -> Uploader:
see tusclient.uploader.Uploader for required and optional arguments.
"""
kwargs['client'] = self
kwargs['client_cert'] = self.client_cert
return Uploader(*args, **kwargs)

def async_uploader(self, *args, **kwargs) -> AsyncUploader:
kwargs['client'] = self
kwargs['client_cert'] = self.client_cert
return AsyncUploader(*args, **kwargs)
26 changes: 19 additions & 7 deletions tusclient/request.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

import requests
import aiohttp
import ssl

from tusclient.exceptions import TusUploadFailed, TusCommunicationError

Expand Down Expand Up @@ -80,10 +81,14 @@ def perform(self):
try:
chunk = self.file.read(self._content_length)
self.add_checksum(chunk)
resp = requests.patch(self._url, data=chunk,
headers=self._request_headers,
verify=self.verify_tls_cert,
cert=self.client_cert,)
resp = requests.patch(
self._url,
data=chunk,
headers=self._request_headers,
verify=self.verify_tls_cert,
stream=True,
cert=self.client_cert
)
self.status_code = resp.status_code
self.response_content = resp.content
self.response_headers = {k.lower(): v for k, v in resp.headers.items()}
Expand All @@ -107,10 +112,17 @@ async def perform(self):
chunk = self.file.read(self._content_length)
self.add_checksum(chunk)
try:
async with aiohttp.ClientSession(loop=self.io_loop) as session:
ssl = None if self.verify_tls_cert else False
ssl_ctx = ssl.create_default_context()
if (self.client_cert is not None):
if self.client_cert is str:
ssl_ctx.load_cert_chain(certfile=self.client_cert)
else:
ssl_ctx.load_cert_chain(certfile=self.client_cert[0], keyfile=self.client_cert[1])
conn = aiohttp.TCPConnector(ssl=ssl_ctx)
async with aiohttp.ClientSession(loop=self.io_loop, connector=conn) as session:
verify_tls_cert = None if self.verify_tls_cert else False
async with session.patch(
self._url, data=chunk, headers=self._request_headers, ssl=ssl
self._url, data=chunk, headers=self._request_headers, ssl=verify_tls_cert
) as resp:
self.status_code = resp.status
self.response_headers = {
Expand Down
7 changes: 5 additions & 2 deletions tusclient/uploader/baseuploader.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,6 @@ def __init__(
url_storage: Optional[Storage] = None,
fingerprinter: Optional[interface.Fingerprint] = None,
upload_checksum=False,
client_cert: Optional[Tuple[str, str]] = None,
):
if file_path is None and file_stream is None:
raise ValueError("Either 'file_path' or 'file_stream' cannot be None.")
Expand All @@ -132,7 +131,6 @@ def __init__(
self.file_stream = file_stream
self.stop_at = self.get_file_size()
self.client = client
self.client_cert = client_cert
self.metadata = metadata or {}
self.metadata_encoding = metadata_encoding
self.store_url = store_url
Expand Down Expand Up @@ -179,6 +177,11 @@ def checksum_algorithm_name(self):
"""
return self.__checksum_algorithm_name

@property
def client_cert(self):
"""The client certificate used for the configured client"""
return self.client.client_cert if self.client is not None else None

@catch_requests_error
def get_offset(self):
"""
Expand Down
14 changes: 11 additions & 3 deletions tusclient/uploader/uploader.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

import requests
import aiohttp
import ssl

from tusclient.uploader.baseuploader import BaseUploader

Expand Down Expand Up @@ -150,11 +151,18 @@ async def create_url(self):
Makes request to tus server to create a new upload url for the required file upload.
"""
try:
async with aiohttp.ClientSession() as session:
ssl_ctx = ssl.create_default_context()
if (self.client_cert is not None):
if self.client_cert is str:
ssl_ctx.load_cert_chain(certfile=self.client_cert)
else:
ssl_ctx.load_cert_chain(certfile=self.client_cert[0], keyfile=self.client_cert[1])
conn = aiohttp.TCPConnector(ssl=ssl_ctx)
async with aiohttp.ClientSession(connector=conn) as session:
headers = self.get_url_creation_headers()
ssl = None if self.verify_tls_cert else False
verify_tls_cert = None if self.verify_tls_cert else False
async with session.post(
self.client.url, headers=headers, ssl=ssl
self.client.url, headers=headers, ssl=verify_tls_cert
) as resp:
url = resp.headers.get("location")
if url is None:
Expand Down

0 comments on commit 3843d6a

Please sign in to comment.