Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add require_csrf argument to add_jsonrpc_endpoint #47

Open
wants to merge 3 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions CONTRIBUTORS.txt
Original file line number Diff line number Diff line change
Expand Up @@ -108,3 +108,5 @@ Contributors
- Donald Stufft, 8/11/2015

- Ben Holzman, 11/17/2015

- Antti Haapala, 8/29/2017
22 changes: 18 additions & 4 deletions pyramid_rpc/jsonrpc.py
Original file line number Diff line number Diff line change
Expand Up @@ -316,10 +316,11 @@ def batched_request_view(request):


class Endpoint(object):
def __init__(self, name, default_mapper, default_renderer):
def __init__(self, name, default_mapper, default_renderer, require_csrf):
self.name = name
self.default_mapper = default_mapper
self.default_renderer = default_renderer
self.require_csrf = require_csrf


def add_jsonrpc_endpoint(config, name, *args, **kw):
Expand All @@ -341,17 +342,24 @@ def add_jsonrpc_endpoint(config, name, *args, **kw):
string name of the renderer, registered via
:meth:`pyramid.config.Configurator.add_renderer`.

``require_csrf``

If this argument is specified and is not ``None``, the value will
be passed as the ``require_csrf`` argument to each of the endpoint's
methods, and the batch request view and error view registration.

A JSON-RPC method also accepts all of the arguments supplied to
:meth:`pyramid.config.Configurator.add_route`.

"""
default_mapper = kw.pop('default_mapper', MapplyViewMapper)
default_renderer = kw.pop('default_renderer', DEFAULT_RENDERER)
require_csrf = kw.pop('require_csrf', None)

endpoint = Endpoint(
name,
default_mapper=default_mapper,
default_renderer=default_renderer,
require_csrf=require_csrf
)

config.registry.jsonrpc_endpoints[name] = endpoint
Expand All @@ -363,9 +371,11 @@ def add_jsonrpc_endpoint(config, name, *args, **kw):
kw['jsonrpc_batched'] = True
kw['renderer'] = null_renderer
config.add_view(batched_request_view, route_name=name,
permission=NO_PERMISSION_REQUIRED, **kw)
permission=NO_PERMISSION_REQUIRED,
require_csrf=require_csrf, **kw)
config.add_view(exception_view, route_name=name, context=Exception,
permission=NO_PERMISSION_REQUIRED)
permission=NO_PERMISSION_REQUIRED,
require_csrf=require_csrf)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is an exception view, it should always have require_csrf=False.

Also the JsonRpcError exception view in the includeme should also have require_csrf=False.



def add_jsonrpc_method(config, view, **kw):
Expand Down Expand Up @@ -416,6 +426,10 @@ def add_jsonrpc_method(config, view, **kw):
mapper = endpoint.default_mapper
kw['mapper'] = mapper

if 'require_csrf' not in kw and endpoint.require_csrf is not None:
# only override mapper if not supplied
kw['require_csrf'] = endpoint.require_csrf

renderer = kw.pop('renderer', None)
if renderer is None:
renderer = endpoint.default_renderer
Expand Down
108 changes: 107 additions & 1 deletion pyramid_rpc/tests/test_jsonrpc.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,18 @@
import unittest

from pyramid import testing

from pyramid.exceptions import BadCSRFToken
from webtest import TestApp


class DummySessionFactory(object):
def __init__(self, request):
pass

def get_csrf_token(self):
return 'abc'


class Test_add_jsonrpc_method(unittest.TestCase):

def setUp(self):
Expand Down Expand Up @@ -537,6 +545,104 @@ def view(request, a):
result = self._callFUT(app, 'dummy', [val])
self.assertEqual(result['result'], val)

def test_require_csrf_False(self):
def view(request):
return 'this must return'

config = self.config
config.include('pyramid_rpc.jsonrpc')
config.set_default_csrf_options(require_csrf=True)
config.add_jsonrpc_endpoint('rpc', '/api/jsonrpc', require_csrf=False)
config.add_jsonrpc_method(view, endpoint='rpc', method='dummy')
app = config.make_wsgi_app()
app = TestApp(app)
result = self._callFUT(app, 'dummy', [], expect_error=False)
self.assertEqual(result['result'], 'this must return')

def test_require_csrf_True(self):
config = self.config
config.include('pyramid_rpc.jsonrpc')
config.set_session_factory(DummySessionFactory)
config.add_jsonrpc_endpoint('rpc', '/api/jsonrpc', require_csrf=True)
config.add_jsonrpc_method(lambda: 'not actually called',
endpoint='rpc', method='dummy')
app = config.make_wsgi_app()
app = TestApp(app)
with self.assertRaises(BadCSRFToken):
self._callFUT(app, 'dummy', [])

def test_require_csrf_overrideable_on_method(self):
def view(request):
return 'this must return'
config = self.config
config.include('pyramid_rpc.jsonrpc')
config.add_jsonrpc_endpoint('rpc', '/api/jsonrpc', require_csrf=True)
config.add_jsonrpc_method(view, endpoint='rpc',
method='dummy', require_csrf=False)
app = config.make_wsgi_app()
app = TestApp(app)
result = self._callFUT(app, 'dummy', [], expect_error=False)
self.assertEqual(result['result'], 'this must return')

def test_error_require_csrf_False(self):
config = self.config
config.include('pyramid_rpc.jsonrpc')
config.set_default_csrf_options(require_csrf=True)
config.add_jsonrpc_endpoint('rpc', '/api/jsonrpc', require_csrf=False)
app = config.make_wsgi_app()
app = TestApp(app)
result = self._callFUT(app, 'err', [], expect_error=True)
self.assertEqual(result['error']['code'], -32601) # invalid method

def test_error_require_csrf_True(self):
config = self.config
config.set_session_factory(DummySessionFactory)
config.include('pyramid_rpc.jsonrpc')
config.add_jsonrpc_endpoint('rpc', '/api/jsonrpc', require_csrf=True)
app = config.make_wsgi_app()
app = TestApp(app)
with self.assertRaises(BadCSRFToken):
self._callFUT(app, 'err', [], expect_error=True)

def test_it_with_batched_requests_require_csrf_False(self):
def view(request, a, b):
return [a, b]
config = self.config
config.include('pyramid_rpc.jsonrpc')
config.set_default_csrf_options(require_csrf=True)
config.add_jsonrpc_endpoint('rpc', '/api/jsonrpc', require_csrf=False)
config.add_jsonrpc_method(view, endpoint='rpc', method='dummy')
app = config.make_wsgi_app()
app = TestApp(app)
body = [
{'id': 1, 'jsonrpc': '2.0', 'method': 'dummy', 'params': [2, 3]},
{'id': 2, 'jsonrpc': '2.0', 'method': 'dummy', 'params': {'a': 3, 'b': 2}},
]
resp = app.post('/api/jsonrpc', content_type='application/json',
params=json.dumps(body))
self.assertEqual(resp.status_int, 200)
result = resp.json
result1 = [r for r in result if r['id'] == 1][0]
result2 = [r for r in result if r['id'] == 2][0]
self.assertEqual(result1, {'id': 1, 'jsonrpc': '2.0', 'result': [2, 3]})
self.assertEqual(result2, {'id': 2, 'jsonrpc': '2.0', 'result': [3, 2]})

def test_it_with_batched_requests_require_csrf_True_must_fail(self):
config = self.config
config.set_session_factory(DummySessionFactory)
config.include('pyramid_rpc.jsonrpc')
config.add_jsonrpc_endpoint('rpc', '/api/jsonrpc', require_csrf=True)
config.add_jsonrpc_method(lambda: 'this is not actually called', endpoint='rpc', method='dummy')
app = config.make_wsgi_app()
app = TestApp(app)
body = [
{'id': 1, 'jsonrpc': '2.0', 'method': 'dummy', 'params': [2, 3]},
{'id': 2, 'jsonrpc': '2.0', 'method': 'dummy', 'params': {'a': 3, 'b': 2}},
]
with self.assertRaises(BadCSRFToken):
app.post('/api/jsonrpc', content_type='application/json',
params=json.dumps(body))


class TestGET(unittest.TestCase):

Expand Down