Skip to content

Commit

Permalink
added OIDC support
Browse files Browse the repository at this point in the history
  • Loading branch information
tmaeno committed Oct 1, 2020
1 parent b075c48 commit b5dff6a
Show file tree
Hide file tree
Showing 4 changed files with 272 additions and 17 deletions.
3 changes: 3 additions & 0 deletions ChangeLog.txt
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
** Release Notes

1.4.40
* added OIDC support

1.4.39
* to ignore --rootVer when --useAthenaPackage/athenaTag

Expand Down
67 changes: 51 additions & 16 deletions pandatools/Client.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
from . import MiscUtils
from .MiscUtils import commands_get_status_output, commands_get_output, pickle_loads
from . import PLogger
from . import openidc_utils

# configuration
try:
Expand Down Expand Up @@ -88,15 +89,37 @@ def __init__(self):
# path to curl
self.path = 'curl --user-agent "dqcurl" '
# verification of the host certificate
self.verifyHost = True
if 'PANDA_VERIFY_HOST' in os.environ and os.environ['PANDA_VERIFY_HOST'] == 'off':
self.verifyHost = False
else:
self.verifyHost = True
# request a compressed response
self.compress = True
# SSL cert/key
self.sslCert = ''
self.sslKey = ''
# auth mode
self.idToken = None
if 'PANDA_AUTH' in os.environ and os.environ['PANDA_AUTH'] == 'oidc':
self.authMode = 'oidc'
else:
self.authMode = 'voms'
# verbose
self.verbose = False

# get token
def getToken(self):
tmp_log = PLogger.getPandaLogger()
oidc = openidc_utils.OpenIdConnect_Utils(os.environ['PANDA_CONFIG_ROOT'], tmp_log, self.verbose)
parsed = urlparse(baseURLSSL)
auth_url = '{0}://{1}:{2}/auth/config.json'.format(parsed.scheme, parsed.hostname, parsed.port)
s, o = oidc.run_device_authorization_flow(auth_url)
if not s:
tmp_log.error(o)
return False
self.idToken = o
return True

# randomize IP
def randomize_ip(self, url):
# parse URL
Expand Down Expand Up @@ -124,11 +147,15 @@ def get(self,url,data,rucioAccount=False, via_file=False):
com += ' --capath %s' % tmp_x509_CApath
if self.compress:
com += ' --compressed'
if self.sslCert != '':
com += ' --cert %s' % self.sslCert
com += ' --cacert %s' % self.sslCert
if self.sslKey != '':
com += ' --key %s' % self.sslKey
if self.authMode == 'oidc':
self.getToken()
com += ' -H "Authorization: Bearer {0}"'.format(self.idToken)
else:
if self.sslCert != '':
com += ' --cert %s' % self.sslCert
com += ' --cacert %s' % self.sslCert
if self.sslKey != '':
com += ' --key %s' % self.sslKey
# max time of 10 min
com += ' -m 600'
# add rucio account info
Expand Down Expand Up @@ -191,11 +218,15 @@ def post(self,url,data,rucioAccount=False, is_json=False, via_file=False):
com += ' --capath %s' % tmp_x509_CApath
if self.compress:
com += ' --compressed'
if self.sslCert != '':
com += ' --cert %s' % self.sslCert
com += ' --cacert %s' % self.sslCert
if self.sslKey != '':
com += ' --key %s' % self.sslKey
if self.authMode == 'oidc':
self.getToken()
com += ' -H "Authorization: Bearer {0}"'.format(self.idToken)
else:
if self.sslCert != '':
com += ' --cert %s' % self.sslCert
com += ' --cacert %s' % self.sslCert
if self.sslKey != '':
com += ' --key %s' % self.sslKey
# max time of 10 min
com += ' -m 600'
# add rucio account info
Expand Down Expand Up @@ -261,11 +292,15 @@ def put(self,url,data):
com += ' --capath %s' % tmp_x509_CApath
if self.compress:
com += ' --compressed'
if self.sslCert != '':
com += ' --cert %s' % self.sslCert
com += ' --cacert %s' % self.sslCert
if self.sslKey != '':
com += ' --key %s' % self.sslKey
if self.authMode == 'oidc':
self.getToken()
com += ' -H "Authorization: Bearer {0}"'.format(self.idToken)
else:
if self.sslCert != '':
com += ' --cert %s' % self.sslCert
com += ' --cacert %s' % self.sslCert
if self.sslKey != '':
com += ' --key %s' % self.sslKey
# emulate PUT
for key in data.keys():
com += ' -F "%s=@%s"' % (key,data[key])
Expand Down
2 changes: 1 addition & 1 deletion pandatools/PandaToolsPkgInfo.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
release_version = "1.4.39"
release_version = "1.4.40"
217 changes: 217 additions & 0 deletions pandatools/openidc_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,217 @@
import os
import uuid
import json
import time
import glob
import base64
import datetime

try:
from urllib import urlencode
from urllib2 import urlopen, Request, HTTPError
except ImportError:
from urllib.request import urlopen, Request
from urllib.parse import urlencode
from urllib.error import HTTPError


TOKEN_BASENAME = '.token'
CACHE_PREFIX = '.page_cache_'


class OpenIdConnect_Utils:

# constructor
def __init__(self, token_dir, log_stream, verbose=False):
self.token_dir = token_dir
self.log_stream = log_stream
self.verbose = verbose

# get token path
def get_token_path(self):
return os.path.join(self.token_dir, TOKEN_BASENAME)

# get device code
def get_device_code(self, device_auth_endpoint, client_id, audience):
if self.verbose:
self.log_stream.debug('getting device code')
data = {'client_id': client_id,
'scope': "openid profile email offline_access",
'audience': audience}
rdata = urlencode(data).encode()
req = Request(device_auth_endpoint, rdata)
req.add_header('content-type', 'application/x-www-form-urlencoded')
try:
conn = urlopen(req)
text = conn.read()
if self.verbose:
self.log_stream.debug(text)
return True, json.loads(text)
except HTTPError as e:
return False, 'code={0}. reason={1}. description={2}'.format(e.code, e.reason, e.read())
except Exception as e:
return False, str(e)

# get ID token
def get_id_token(self, token_endpoint, client_id, device_code, interval, expires_in):
if self.verbose:
self.log_stream.debug('getting ID token')
startTime = datetime.datetime.utcnow()
data = {'client_id': client_id,
'grant_type': 'urn:ietf:params:oauth:grant-type:device_code',
'device_code': device_code}
rdata = urlencode(data).encode()
req = Request(token_endpoint, rdata)
req.add_header('content-type', 'application/x-www-form-urlencoded')
while datetime.datetime.utcnow() - startTime < datetime.timedelta(seconds=expires_in):
try:
conn = urlopen(req)
text = conn.read()
if self.verbose:
self.log_stream.debug(text)
id_token = json.loads(text)['id_token']
with open(self.get_token_path(), 'w') as f:
f.write(text)
return True, id_token
except HTTPError as e:
text = e.read()
try:
description = json.loads(text)
# pending
if description['error'] == "authorization_pending":
time.sleep(interval + 1)
continue
except Exception:
pass
return False, 'code={0}. reason={1}. description={2}'.format(e.code, e.reason, text)
except Exception as e:
return False, str(e)

# refresh token
def refresh_token(self, token_endpoint, client_id, client_secret, refresh_token_string):
if self.verbose:
self.log_stream.debug('refreshing token')
data = {'client_id': client_id,
'client_secret': client_secret,
'grant_type': 'refresh_token',
'refresh_token': refresh_token_string}
rdata = urlencode(data).encode()
req = Request(token_endpoint, rdata)
req.add_header('content-type', 'application/x-www-form-urlencoded')
try:
conn = urlopen(req)
text = conn.read()
if self.verbose:
self.log_stream.debug(text)
id_token = json.loads(text)['id_token']
with open(self.get_token_path(), 'w') as f:
f.write(text)
return True, id_token
except HTTPError as e:
return False, 'code={0}. reason={1}. description={2}'.format(e.code, e.reason, e.read())
except Exception as e:
return False, str(e)

# fetch page
def fetch_page(self, url):
path = os.path.join(self.token_dir, CACHE_PREFIX + str(uuid.uuid5(uuid.NAMESPACE_URL, str(url))))
if not os.path.exists(path) or \
datetime.datetime.now() - datetime.datetime.fromtimestamp(os.path.getmtime(path)) > \
datetime.timedelta(hours=1):
if self.verbose:
self.log_stream.debug('fetching {0}'.format(url))
try:
conn = urlopen(url)
text = conn.read()
if self.verbose:
self.log_stream.debug(text)
with open(path, 'w') as f:
f.write(text)
except HTTPError as e:
return False, 'code={0}. reason={1}. description={2}'.format(e.code, e.reason, e.read())
except Exception as e:
return False, str(e)
with open(path) as f:
return True, json.load(f)

# check token expiry
def check_token_expiry(self):
token_file = self.get_token_path()
if os.path.exists(token_file):
with open(token_file) as f:
if self.verbose:
self.log_stream.debug('check {0}'.format(token_file))
try:
# decode ID token
data = json.load(f)
enc = data['id_token'].split('.')[1]
enc += '=' * (-len(enc) % 4)
dec = json.loads(base64.urlsafe_b64decode(enc.encode()))
exp_time = datetime.datetime.fromtimestamp(dec['exp'])
delta = exp_time - datetime.datetime.now()
if self.verbose:
self.log_stream.debug('token expiration time : {0}'.\
format(exp_time.strftime("%Y-%m-%d %H:%M:%S")))
# check expiration time
if delta < datetime.timedelta(minutes=10):
# return refresh token
if 'refresh_token' in data:
if self.verbose:
self.log_stream.debug('to refresh token')
return False, data['refresh_token']
else:
# return valid token
if self.verbose:
self.log_stream.debug('valid token is available')
return True, data['id_token']
except Exception as e:
self.log_stream.error('failed to decode cached token with {0}'.format(e))
if self.verbose:
self.log_stream.debug('cached token unavailable')
return False, None

# run device authorization flow
def run_device_authorization_flow(self, auth_config_url):
# check toke expiry
s, o = self.check_token_expiry()
if s:
# still valid
return True, o
refresh_token_string = o
# get auth config
s, o = self.fetch_page(auth_config_url)
if not s:
return False, "Failed to get Auth configuration"
auth_config = o
# get endpoint config
s, o = self.fetch_page(auth_config['oidc_config_url'])
if not s:
return False, "Failed to get endpoint configuration"
endpoint_config = o
# refresh token
if refresh_token_string is not None:
s, o = self.refresh_token(endpoint_config['token_endpoint'], auth_config['client_id'],
auth_config['client_secret'], refresh_token_string)
# refreshed
if s:
return True, o
# get device code
s, o = self.get_device_code(endpoint_config['device_authorization_endpoint'], auth_config['client_id'],
auth_config['audience'])
if not s:
return False, 'Failed to get device code: ' + o
# get ID token
self.log_stream.info(("Please go to {0} and sign in. "
"Waiting until authentication is completed").format(o['verification_uri_complete']))
s, o = self.get_id_token(endpoint_config['token_endpoint'], auth_config['client_id'],
o['device_code'], o['interval'], o['expires_in'])
if not s:
return False, "Failed to get ID token: " + o
self.log_stream.info('All set')
return True, o

# cleanup
def cleanup(self):
for patt in [TOKEN_BASENAME, CACHE_PREFIX]:
for f in glob.glob(os.path.join(self.token_dir, patt + '*')):
os.remove(f)

0 comments on commit b5dff6a

Please sign in to comment.