Skip to content

Commit

Permalink
refactor: auth_layer
Browse files Browse the repository at this point in the history
  • Loading branch information
aatifsyed committed Apr 30, 2024
1 parent 0b61c97 commit 1fffff5
Show file tree
Hide file tree
Showing 3 changed files with 49 additions and 165 deletions.
199 changes: 39 additions & 160 deletions src/rpc/auth_layer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@
use crate::auth::{verify_token, JWT_IDENTIFIER};
use crate::key_management::KeyStore;
use crate::rpc::{
auth, beacon, chain, common, eth, gas, mpool, net, node, state, sync, wallet, RpcMethod as _,
CANCEL_METHOD_NAME,
auth, beacon, chain, common, eth, gas, mpool, net, node, state, sync, wallet, Permission,
RpcMethod as _, CANCEL_METHOD_NAME,
};
use ahash::{HashMap, HashMapExt as _};
use futures::future::BoxFuture;
Expand All @@ -21,163 +21,42 @@ use tokio::sync::RwLock;
use tower::Layer;
use tracing::debug;

/// Access levels to be checked against JWT claims
enum Access {
Admin,
Sign,
Write,
Read,
}

/// Access mapping between method names and access levels
/// Checked against JWT claims on every request
static ACCESS_MAP: Lazy<HashMap<&str, Access>> = Lazy::new(|| {
static METHOD_NAME2REQUIRED_PERMISSION: Lazy<HashMap<&str, Permission>> = Lazy::new(|| {
let mut access = HashMap::new();

// Chain API
access.insert(chain::ChainGetMessage::NAME, Access::Read);
access.insert(chain::ChainExport::NAME, Access::Read);
access.insert(chain::ChainReadObj::NAME, Access::Read);
access.insert(chain::ChainGetPath::NAME, Access::Read);
access.insert(chain::ChainHasObj::NAME, Access::Read);
access.insert(chain::ChainGetBlockMessages::NAME, Access::Read);
access.insert(chain::ChainGetTipSetByHeight::NAME, Access::Read);
access.insert(chain::ChainGetTipSetAfterHeight::NAME, Access::Read);
access.insert(chain::ChainGetGenesis::NAME, Access::Read);
access.insert(chain::ChainHead::NAME, Access::Read);
access.insert(chain::ChainGetBlock::NAME, Access::Read);
access.insert(chain::ChainGetTipSet::NAME, Access::Read);
access.insert(chain::ChainSetHead::NAME, Access::Admin);
access.insert(chain::ChainGetMinBaseFee::NAME, Access::Admin);
access.insert(chain::ChainTipSetWeight::NAME, Access::Read);
access.insert(chain::ChainGetMessagesInTipset::NAME, Access::Read);
access.insert(chain::ChainGetParentMessages::NAME, Access::Read);
access.insert(chain::CHAIN_NOTIFY, Access::Read);
access.insert(chain::ChainGetParentReceipts::NAME, Access::Read);

// Message Pool API
access.insert(mpool::MpoolGetNonce::NAME, Access::Read);
access.insert(mpool::MpoolPending::NAME, Access::Read);
access.insert(mpool::MpoolSelect::NAME, Access::Read);
// Lotus limits `MPOOL_PUSH`` to `Access::Write`. However, since messages
// can always be pushed over the p2p protocol, limiting the RPC doesn't
// improve security.
access.insert(mpool::MpoolPush::NAME, Access::Read);
access.insert(mpool::MpoolPushMessage::NAME, Access::Sign);

// Sync API
access.insert(sync::SyncCheckBad::NAME, Access::Read);
access.insert(sync::SyncMarkBad::NAME, Access::Admin);
access.insert(sync::SyncState::NAME, Access::Read);
access.insert(sync::SyncSubmitBlock::NAME, Access::Write);

// Wallet API
access.insert(wallet::WalletBalance::NAME, Access::Read);
access.insert(wallet::WalletDefaultAddress::NAME, Access::Read);
access.insert(wallet::WalletExport::NAME, Access::Admin);
access.insert(wallet::WalletHas::NAME, Access::Write);
access.insert(wallet::WalletImport::NAME, Access::Admin);
access.insert(wallet::WalletList::NAME, Access::Write);
access.insert(wallet::WalletNew::NAME, Access::Write);
access.insert(wallet::WalletSetDefault::NAME, Access::Write);
access.insert(wallet::WalletSign::NAME, Access::Sign);
access.insert(wallet::WalletValidateAddress::NAME, Access::Read);
access.insert(wallet::WalletVerify::NAME, Access::Read);
access.insert(wallet::WalletDelete::NAME, Access::Write);

// State API
access.insert(state::MinerGetBaseInfo::NAME, Access::Read);
access.insert(state::StateCall::NAME, Access::Read);
access.insert(state::StateNetworkName::NAME, Access::Read);
access.insert(state::StateReplay::NAME, Access::Read);
access.insert(state::StateGetActor::NAME, Access::Read);
access.insert(state::StateMarketBalance::NAME, Access::Read);
access.insert(state::StateMarketDeals::NAME, Access::Read);
access.insert(state::StateMinerInfo::NAME, Access::Read);
access.insert(state::StateMinerActiveSectors::NAME, Access::Read);
access.insert(state::StateMinerFaults::NAME, Access::Read);
access.insert(state::StateMinerRecoveries::NAME, Access::Read);
access.insert(state::StateMinerPower::NAME, Access::Read);
access.insert(state::StateMinerDeadlines::NAME, Access::Read);
access.insert(state::StateMinerProvingDeadline::NAME, Access::Read);
access.insert(state::StateMinerAvailableBalance::NAME, Access::Read);
access.insert(state::StateGetReceipt::NAME, Access::Read);
access.insert(state::StateWaitMsg::NAME, Access::Read);
access.insert(state::StateSearchMsg::NAME, Access::Read);
access.insert(state::StateSearchMsgLimited::NAME, Access::Read);
access.insert(state::StateNetworkVersion::NAME, Access::Read);
access.insert(state::StateAccountKey::NAME, Access::Read);
access.insert(state::StateLookupID::NAME, Access::Read);
access.insert(state::StateFetchRoot::NAME, Access::Read);
access.insert(state::StateGetRandomnessFromTickets::NAME, Access::Read);
access.insert(state::StateGetRandomnessFromBeacon::NAME, Access::Read);
access.insert(state::StateReadState::NAME, Access::Read);
access.insert(state::StateCirculatingSupply::NAME, Access::Read);
access.insert(state::StateSectorGetInfo::NAME, Access::Read);
access.insert(state::StateListMessages::NAME, Access::Read);
access.insert(state::StateListMiners::NAME, Access::Read);
access.insert(state::StateMinerSectorCount::NAME, Access::Read);
access.insert(state::StateMinerSectors::NAME, Access::Read);
access.insert(state::StateMinerPartitions::NAME, Access::Read);
access.insert(state::StateVerifiedClientStatus::NAME, Access::Read);
access.insert(state::StateMarketStorageDeal::NAME, Access::Read);
access.insert(state::StateVMCirculatingSupplyInternal::NAME, Access::Read);
access.insert(state::MsigGetAvailableBalance::NAME, Access::Read);
access.insert(state::MsigGetPending::NAME, Access::Read);
access.insert(state::StateDealProviderCollateralBounds::NAME, Access::Read);
access.insert(state::StateGetBeaconEntry::NAME, Access::Read);
access.insert(state::StateSectorPreCommitInfo::NAME, Access::Read);

// Gas API
access.insert(gas::GasEstimateGasLimit::NAME, Access::Read);
access.insert(gas::GasEstimateGasPremium::NAME, Access::Read);
access.insert(gas::GasEstimateFeeCap::NAME, Access::Read);
access.insert(gas::GasEstimateMessageGas::NAME, Access::Read);

// Common API
access.insert(common::Version::NAME, Access::Read);
access.insert(common::Session::NAME, Access::Read);
access.insert(common::Shutdown::NAME, Access::Admin);
access.insert(common::StartTime::NAME, Access::Read);

// Net API
access.insert(net::NetAddrsListen::NAME, Access::Read);
access.insert(net::NetPeers::NAME, Access::Read);
access.insert(net::NetListening::NAME, Access::Read);
access.insert(net::NetInfo::NAME, Access::Read);
access.insert(net::NetConnect::NAME, Access::Write);
access.insert(net::NetDisconnect::NAME, Access::Write);
access.insert(net::NetAgentVersion::NAME, Access::Read);
access.insert(net::NetAutoNatStatus::NAME, Access::Read);
access.insert(net::NetVersion::NAME, Access::Read);

// Node API
access.insert(node::NodeStatus::NAME, Access::Read);

// Eth API
access.insert(eth::EthAccounts::NAME, Access::Read);
access.insert(eth::EthBlockNumber::NAME, Access::Read);
access.insert(eth::EthChainId::NAME, Access::Read);
access.insert(eth::EthGasPrice::NAME, Access::Read);
access.insert(eth::EthGetBalance::NAME, Access::Read);
access.insert(eth::EthSyncing::NAME, Access::Read);
access.insert(eth::EthGetBlockByNumber::NAME, Access::Read);
access.insert(eth::Web3ClientVersion::NAME, Access::Read);

// Pubsub API
access.insert(CANCEL_METHOD_NAME, Access::Read);
macro_rules! insert {
($ty:ty) => {
access.insert(<$ty>::NAME, <$ty>::PERMISSION);
};
}

auth::for_each_method!(insert);
beacon::for_each_method!(insert);
chain::for_each_method!(insert);
common::for_each_method!(insert);
gas::for_each_method!(insert);
mpool::for_each_method!(insert);
net::for_each_method!(insert);
state::for_each_method!(insert);
node::for_each_method!(insert);
sync::for_each_method!(insert);
wallet::for_each_method!(insert);
eth::for_each_method!(insert);

access.insert(chain::CHAIN_NOTIFY, Permission::Read);
access.insert(CANCEL_METHOD_NAME, Permission::Read);

access
});

/// Checks an access enumeration against provided JWT claims
fn check_access(access: &Access, claims: &[String]) -> bool {
match access {
Access::Admin => claims.contains(&"admin".to_owned()),
Access::Sign => claims.contains(&"sign".to_owned()),
Access::Write => claims.contains(&"write".to_owned()),
Access::Read => claims.contains(&"read".to_owned()),
}
fn is_allowed(required_by_method: Permission, claimed_by_user: &[String]) -> bool {
let needle = match required_by_method {
Permission::Admin => "admin",
Permission::Sign => "sign",
Permission::Write => "write",
Permission::Read => "read",
};
claimed_by_user.iter().any(|haystack| haystack == needle)
}

#[derive(Clone)]
Expand All @@ -187,10 +66,10 @@ pub struct AuthLayer {
}

impl<S> Layer<S> for AuthLayer {
type Service = AuthMiddleware<S>;
type Service = Auth<S>;

fn layer(&self, service: S) -> Self::Service {
AuthMiddleware {
Auth {
headers: self.headers.clone(),
keystore: self.keystore.clone(),
service,
Expand All @@ -199,13 +78,13 @@ impl<S> Layer<S> for AuthLayer {
}

#[derive(Clone)]
pub struct AuthMiddleware<S> {
pub struct Auth<S> {
headers: HeaderMap,
keystore: Arc<RwLock<KeyStore>>,
service: S,
}

impl<'a, S> RpcServiceT<'a> for AuthMiddleware<S>
impl<'a, S> RpcServiceT<'a> for Auth<S>
where
S: RpcServiceT<'a> + Send + Sync + Clone + 'static,
{
Expand Down Expand Up @@ -260,9 +139,9 @@ async fn check_permissions(
};
debug!("Decoded JWT Claims: {}", claims.join(","));

match ACCESS_MAP.get(&method) {
Some(access) => {
if check_access(access, &claims) {
match METHOD_NAME2REQUIRED_PERMISSION.get(&method) {
Some(required_by_method) => {
if is_allowed(*required_by_method, &claims) {
Ok(())
} else {
Err(ErrorCode::InvalidRequest)
Expand Down
13 changes: 9 additions & 4 deletions src/rpc/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ mod error;
mod reflect;
pub mod types;
pub use methods::*;
pub(self) use reflect::Permission;
use reflect::Permission;

/// Protocol or transport-specific error
#[allow(unused)]
Expand Down Expand Up @@ -66,10 +66,8 @@ pub mod prelude {
/// - If it is used _across_ API verticals, it should live in `src/rpc/types.rs`
///
/// # Interactions with the [`lotus_json`] APIs
/// - Types defined in the module will only ever be deserialized as JSON, so there
/// will NEVER be a need to implement [`HasLotusJson`] for them.
/// - Types may have fields which must go through [`LotusJson`],
/// and must reflect that in their [`JsonSchema`].
/// and MUST reflect that in their [`JsonSchema`].
/// You have two options for this:
/// - Use `#[attributes]` to control serialization and schema generation:
/// ```ignore
Expand All @@ -89,6 +87,13 @@ pub mod prelude {
/// }
/// ```
///
/// # `for_each_method`
/// Each API vertical exposes a [`for_each_method!`](auth::for_each_method) macro,
/// which is used in three places:
/// - [`prelude`], where all the methods are exported for use in the codebase.
/// - [`auth_layer`], where their [`RpcMethod::PERMISSION`]s are registered.
/// - [`create_module`], where they're actually registered to be served.
///
/// [`lotus_json`]: crate::lotus_json
/// [`HasLotusJson`]: crate::lotus_json::HasLotusJson
/// [`LotusJson`]: crate::lotus_json::LotusJson
Expand Down
2 changes: 1 addition & 1 deletion src/rpc/reflect/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ pub trait RpcMethod<const ARITY: usize> {
/// See [`ApiVersion`].
const API_VERSION: ApiVersion;
/// See [`Permission`]
const PERMISSION: Permission = Permission::Read; // TODO(aatifsyed): removeme
const PERMISSION: Permission;
/// Types of each argument. [`Option`]-al arguments MUST follow mandatory ones.
type Params: Params<ARITY>;
/// Return value of this method.
Expand Down

0 comments on commit 1fffff5

Please sign in to comment.