diff --git a/lib/stytch/b2b_client.rb b/lib/stytch/b2b_client.rb index 9cbdb42..c081236 100644 --- a/lib/stytch/b2b_client.rb +++ b/lib/stytch/b2b_client.rb @@ -6,15 +6,17 @@ require_relative 'b2b_organizations' require_relative 'b2b_otp' require_relative 'b2b_passwords' +require_relative 'b2b_rbac' require_relative 'b2b_sessions' require_relative 'b2b_sso' require_relative 'm2m' +require_relative 'rbac_local' module StytchB2B class Client ENVIRONMENTS = %i[live test].freeze - attr_reader :discovery, :m2m, :magic_links, :oauth, :otps, :organizations, :passwords, :sso, :sessions + attr_reader :discovery, :m2m, :magic_links, :oauth, :otps, :organizations, :passwords, :rbac, :sso, :sessions def initialize(project_id:, secret:, env: nil, &block) @api_host = api_host(env, project_id) @@ -23,15 +25,19 @@ def initialize(project_id:, secret:, env: nil, &block) create_connection(&block) + rbac = StytchB2B::RBAC.new(@connection) + @policy_cache = StytchB2B::PolicyCache.new(rbac_client: rbac) + @discovery = StytchB2B::Discovery.new(@connection) - @m2m = Stytch::M2M.new(@connection, project_id) + @m2m = Stytch::M2M.new(@connection, @project_id) @magic_links = StytchB2B::MagicLinks.new(@connection) @oauth = StytchB2B::OAuth.new(@connection) @otps = StytchB2B::OTPs.new(@connection) @organizations = StytchB2B::Organizations.new(@connection) @passwords = StytchB2B::Passwords.new(@connection) + @rbac = StytchB2B::RBAC.new(@connection) @sso = StytchB2B::SSO.new(@connection) - @sessions = StytchB2B::Sessions.new(@connection, project_id) + @sessions = StytchB2B::Sessions.new(@connection, @project_id, @policy_cache) end private diff --git a/lib/stytch/b2b_rbac.rb b/lib/stytch/b2b_rbac.rb new file mode 100644 index 0000000..5b7f7f2 --- /dev/null +++ b/lib/stytch/b2b_rbac.rb @@ -0,0 +1,25 @@ +# frozen_string_literal: true + +# !!! +# WARNING: This file is autogenerated +# Only modify code within MANUAL() sections +# or your changes may be overwritten later! +# !!! + +require_relative 'request_helper' + +module StytchB2B + class RBAC + include Stytch::RequestHelper + + def initialize(connection) + @connection = connection + end + + def policy + query_params = {} + request = request_with_query_params('/v1/b2b/rbac/policy', query_params) + get_request(request) + end + end +end diff --git a/lib/stytch/b2b_sessions.rb b/lib/stytch/b2b_sessions.rb index 4b913f4..166d464 100644 --- a/lib/stytch/b2b_sessions.rb +++ b/lib/stytch/b2b_sessions.rb @@ -15,9 +15,10 @@ module StytchB2B class Sessions include Stytch::RequestHelper - def initialize(connection, project_id) + def initialize(connection, project_id, policy_cache) @connection = connection + @policy_cache = policy_cache @project_id = project_id @cache_last_update = 0 @jwks_loader = lambda do |options| @@ -95,6 +96,9 @@ def get( # delete a key, supply a null value. Custom claims made with reserved claims (`iss`, `sub`, `aud`, `exp`, `nbf`, `iat`, `jti`) will be ignored. # Total custom claims size cannot exceed four kilobytes. # The type of this field is nilable +object+. + # authorization_check:: + # (no documentation yet) + # The type of this field is nilable +AuthorizationCheck+ (+object+). # # == Returns: # An object with the following fields: @@ -123,13 +127,15 @@ def authenticate( session_token: nil, session_duration_minutes: nil, session_jwt: nil, - session_custom_claims: nil + session_custom_claims: nil, + authorization_check: nil ) request = {} request[:session_token] = session_token unless session_token.nil? request[:session_duration_minutes] = session_duration_minutes unless session_duration_minutes.nil? request[:session_jwt] = session_jwt unless session_jwt.nil? request[:session_custom_claims] = session_custom_claims unless session_custom_claims.nil? + request[:authorization_check] = authorization_check unless authorization_check.nil? post_request('/v1/b2b/sessions/authenticate', request) end @@ -325,17 +331,21 @@ def get_jwks( # If max_token_age_seconds is set and the JWT was issued (based on the "iat" claim) less than # max_token_age_seconds seconds ago, then just verify locally and don't call the API # To force remote validation for all tokens, set max_token_age_seconds to 0 or call authenticate() + # Note that the 'user_id' field of the returned session is DEPRECATED: Use member_id instead + # This field will be removed in a future MAJOR release. def authenticate_jwt( session_jwt, max_token_age_seconds: nil, session_duration_minutes: nil, - session_custom_claims: nil + session_custom_claims: nil, + authorization_check: nil ) if max_token_age_seconds == 0 return authenticate( session_jwt: session_jwt, session_duration_minutes: session_duration_minutes, - session_custom_claims: session_custom_claims + session_custom_claims: session_custom_claims, + authorization_check: authorization_check ) end @@ -343,12 +353,21 @@ def authenticate_jwt( iat_time = Time.at(decoded_jwt['iat']).to_datetime if iat_time + max_token_age_seconds >= Time.now session = marshal_jwt_into_session(decoded_jwt) - { 'session' => session } + if authorization_check && session['roles'] + @policy_cache.perform_authorization_check( + subject_roles: session['roles'], + subject_org_id: session['organization_id'], + authorization_check: authorization_check + ) + end + + { 'session' => session['member_session'] } else authenticate( session_jwt: session_jwt, session_duration_minutes: session_duration_minutes, - session_custom_claims: session_custom_claims + session_custom_claims: session_custom_claims, + authorization_check: authorization_check ) end rescue StandardError @@ -356,7 +375,8 @@ def authenticate_jwt( authenticate( session_jwt: session_jwt, session_duration_minutes: session_duration_minutes, - session_custom_claims: session_custom_claims + session_custom_claims: session_custom_claims, + authorization_check: authorization_check ) end @@ -381,24 +401,35 @@ def authenticate_jwt_local(session_jwt) end end + # Note that the 'user_id' field is DEPRECATED: Use member_id instead + # This field will be removed in a future MAJOR release. def marshal_jwt_into_session(jwt) stytch_claim = 'https://stytch.com/session' + organization_claim = 'https://stytch.com/organization' + roles_claim = 'https://stytch.com/roles' + expires_at = jwt[stytch_claim]['expires_at'] || Time.at(jwt['exp']).to_datetime.utc.strftime('%Y-%m-%dT%H:%M:%SZ') # The custom claim set is all the claims in the payload except for the standard claims and # the Stytch session claim. The cleanest way to collect those seems to be naming what we want # to omit and filtering the rest to collect the custom claims. - reserved_claims = ['aud', 'exp', 'iat', 'iss', 'jti', 'nbf', 'sub', stytch_claim] + reserved_claims = ['aud', 'exp', 'iat', 'iss', 'jti', 'nbf', 'sub', stytch_claim, organization_claim, roles_claim] custom_claims = jwt.reject { |key, _| reserved_claims.include?(key) } { - 'session_id' => jwt[stytch_claim]['id'], - 'user_id' => jwt['sub'], - 'started_at' => jwt[stytch_claim]['started_at'], - 'last_accessed_at' => jwt[stytch_claim]['last_accessed_at'], - # For JWTs that include it, prefer the inner expires_at claim. - 'expires_at' => expires_at, - 'attributes' => jwt[stytch_claim]['attributes'], - 'authentication_factors' => jwt[stytch_claim]['authentication_factors'], - 'custom_claims' => custom_claims + 'member_session' => { + 'session_id' => jwt[stytch_claim]['id'], + 'organization_id' => jwt[organization_claim]['id'], + 'member_id' => jwt['sub'], + # DEPRECATED: Use member_id instead + 'user_id' => jwt['sub'], + 'started_at' => jwt[stytch_claim]['started_at'], + 'last_accessed_at' => jwt[stytch_claim]['last_accessed_at'], + # For JWTs that include it, prefer the inner expires_at claim. + 'expires_at' => expires_at, + 'attributes' => jwt[stytch_claim]['attributes'], + 'authentication_factors' => jwt[stytch_claim]['authentication_factors'], + 'custom_claims' => custom_claims + }, + 'roles' => jwt[roles_claim] } end # ENDMANUAL(Sessions::authenticate_jwt) diff --git a/lib/stytch/client.rb b/lib/stytch/client.rb index d7b315d..9b42f8b 100644 --- a/lib/stytch/client.rb +++ b/lib/stytch/client.rb @@ -25,12 +25,12 @@ def initialize(project_id:, secret:, env: nil, &block) create_connection(&block) @crypto_wallets = Stytch::CryptoWallets.new(@connection) - @m2m = Stytch::M2M.new(@connection, project_id) + @m2m = Stytch::M2M.new(@connection, @project_id) @magic_links = Stytch::MagicLinks.new(@connection) @oauth = Stytch::OAuth.new(@connection) @otps = Stytch::OTPs.new(@connection) @passwords = Stytch::Passwords.new(@connection) - @sessions = Stytch::Sessions.new(@connection, project_id) + @sessions = Stytch::Sessions.new(@connection, @project_id) @totps = Stytch::TOTPs.new(@connection) @users = Stytch::Users.new(@connection) @webauthn = Stytch::WebAuthn.new(@connection) diff --git a/lib/stytch/errors.rb b/lib/stytch/errors.rb index acab36e..1ffdc02 100644 --- a/lib/stytch/errors.rb +++ b/lib/stytch/errors.rb @@ -35,4 +35,18 @@ def initialize(scope) super(msg) end end + + class TenancyError < StandardError + def initialize(subject_org_id, request_org_id) + msg = "Subject organization_id #{subject_org_id} does not match authZ request organization_id #{request_org_id}" + super(msg) + end + end + + class PermissionError < StandardError + def initialize(request) + msg = "Permission denied for request #{request}" + super(msg) + end + end end diff --git a/lib/stytch/rbac_local.rb b/lib/stytch/rbac_local.rb new file mode 100644 index 0000000..63916f3 --- /dev/null +++ b/lib/stytch/rbac_local.rb @@ -0,0 +1,56 @@ +# frozen_string_literal: true + +require_relative 'request_helper' + +module StytchB2B + class PolicyCache + def initialize(rbac_client:) + @rbac_client = rbac_client + @policy_last_update = 0 + @cached_policy = nil + end + + def reload_policy + @cached_policy = rbac_client.get_policy + @policy_last_update = Time.now.to_i + end + + def get_policy(invalidate: false) + reload_policy if invalidate || @cached_policy.nil? || @policy_last_update < Time.now.to_i - 300 + @cached_policy + end + + # Performs an authorization check against the project's policy and a set of roles. If the + # check succeeds, this method will return. If the check fails, a PermissionError + # will be raised. It's also possible for a TenancyError to be raised if the + # subject_org_id does not match the authZ request organization_id. + # authorization_check is an object with keys 'action', 'resource_id', and 'organization_id' + def perform_authorization_check( + subject_roles:, + subject_org_id:, + authorization_check: + ) + raise TenancyError, subject_org_id if subject_org_id != authorization_check['organization_id'] + + policy = get_policy + + for role in policy['roles'] + next unless subject_roles.include?(role['role_id']) + + for permission in role['permissions'] + actions = permission['actions'] + resource = permission['resource_id'] + has_matching_action = actions.include?('*') || actions.include?(authorization_check['action']) + has_matching_resource = resource == authorization_check['resource_id'] + if has_matching_action && has_matching_resource + # All good + return + end + end + end + + # If we get here, we didn't find a matching permission + raise PermissionError, authorization_check + end + end +end