Skip to content

Commit

Permalink
refactor: trait RpcMethod { const PERMISSION: ... } (#4277)
Browse files Browse the repository at this point in the history
  • Loading branch information
aatifsyed authored Apr 30, 2024
1 parent f3f84f5 commit 36d094f
Show file tree
Hide file tree
Showing 15 changed files with 183 additions and 182 deletions.
206 changes: 39 additions & 167 deletions src/rpc/auth_layer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@
use crate::auth::{verify_token, JWT_IDENTIFIER};
use crate::key_management::KeyStore;
use crate::rpc::{
auth, beacon, chain, common, eth, gas, mpool, net, node, state, sync, wallet, RpcMethod as _,
CANCEL_METHOD_NAME,
auth, beacon, chain, common, eth, gas, mpool, net, node, state, sync, wallet, Permission,
RpcMethod as _, CANCEL_METHOD_NAME,
};
use ahash::{HashMap, HashMapExt as _};
use futures::future::BoxFuture;
Expand All @@ -21,170 +21,42 @@ use tokio::sync::RwLock;
use tower::Layer;
use tracing::debug;

/// Access levels to be checked against JWT claims
enum Access {
Admin,
Sign,
Write,
Read,
}

/// Access mapping between method names and access levels
/// Checked against JWT claims on every request
static ACCESS_MAP: Lazy<HashMap<&str, Access>> = Lazy::new(|| {
static METHOD_NAME2REQUIRED_PERMISSION: Lazy<HashMap<&str, Permission>> = Lazy::new(|| {
let mut access = HashMap::new();

// Auth API
access.insert(auth::AuthNew::NAME, Access::Admin);
access.insert(auth::AuthVerify::NAME, Access::Read);

// Beacon API
access.insert(beacon::BeaconGetEntry::NAME, Access::Read);

// Chain API
access.insert(chain::ChainGetMessage::NAME, Access::Read);
access.insert(chain::ChainExport::NAME, Access::Read);
access.insert(chain::ChainReadObj::NAME, Access::Read);
access.insert(chain::ChainGetPath::NAME, Access::Read);
access.insert(chain::ChainHasObj::NAME, Access::Read);
access.insert(chain::ChainGetBlockMessages::NAME, Access::Read);
access.insert(chain::ChainGetTipSetByHeight::NAME, Access::Read);
access.insert(chain::ChainGetTipSetAfterHeight::NAME, Access::Read);
access.insert(chain::ChainGetGenesis::NAME, Access::Read);
access.insert(chain::ChainHead::NAME, Access::Read);
access.insert(chain::ChainGetBlock::NAME, Access::Read);
access.insert(chain::ChainGetTipSet::NAME, Access::Read);
access.insert(chain::ChainSetHead::NAME, Access::Admin);
access.insert(chain::ChainGetMinBaseFee::NAME, Access::Admin);
access.insert(chain::ChainTipSetWeight::NAME, Access::Read);
access.insert(chain::ChainGetMessagesInTipset::NAME, Access::Read);
access.insert(chain::ChainGetParentMessages::NAME, Access::Read);
access.insert(chain::CHAIN_NOTIFY, Access::Read);
access.insert(chain::ChainGetParentReceipts::NAME, Access::Read);

// Message Pool API
access.insert(mpool::MpoolGetNonce::NAME, Access::Read);
access.insert(mpool::MpoolPending::NAME, Access::Read);
access.insert(mpool::MpoolSelect::NAME, Access::Read);
// Lotus limits `MPOOL_PUSH`` to `Access::Write`. However, since messages
// can always be pushed over the p2p protocol, limiting the RPC doesn't
// improve security.
access.insert(mpool::MpoolPush::NAME, Access::Read);
access.insert(mpool::MpoolPushMessage::NAME, Access::Sign);

// Sync API
access.insert(sync::SyncCheckBad::NAME, Access::Read);
access.insert(sync::SyncMarkBad::NAME, Access::Admin);
access.insert(sync::SyncState::NAME, Access::Read);
access.insert(sync::SyncSubmitBlock::NAME, Access::Write);

// Wallet API
access.insert(wallet::WalletBalance::NAME, Access::Read);
access.insert(wallet::WalletDefaultAddress::NAME, Access::Read);
access.insert(wallet::WalletExport::NAME, Access::Admin);
access.insert(wallet::WalletHas::NAME, Access::Write);
access.insert(wallet::WalletImport::NAME, Access::Admin);
access.insert(wallet::WalletList::NAME, Access::Write);
access.insert(wallet::WalletNew::NAME, Access::Write);
access.insert(wallet::WalletSetDefault::NAME, Access::Write);
access.insert(wallet::WalletSign::NAME, Access::Sign);
access.insert(wallet::WalletValidateAddress::NAME, Access::Read);
access.insert(wallet::WalletVerify::NAME, Access::Read);
access.insert(wallet::WalletDelete::NAME, Access::Write);

// State API
access.insert(state::MinerGetBaseInfo::NAME, Access::Read);
access.insert(state::StateCall::NAME, Access::Read);
access.insert(state::StateNetworkName::NAME, Access::Read);
access.insert(state::StateReplay::NAME, Access::Read);
access.insert(state::StateGetActor::NAME, Access::Read);
access.insert(state::StateMarketBalance::NAME, Access::Read);
access.insert(state::StateMarketDeals::NAME, Access::Read);
access.insert(state::StateMinerInfo::NAME, Access::Read);
access.insert(state::StateMinerActiveSectors::NAME, Access::Read);
access.insert(state::StateMinerFaults::NAME, Access::Read);
access.insert(state::StateMinerRecoveries::NAME, Access::Read);
access.insert(state::StateMinerPower::NAME, Access::Read);
access.insert(state::StateMinerDeadlines::NAME, Access::Read);
access.insert(state::StateMinerProvingDeadline::NAME, Access::Read);
access.insert(state::StateMinerAvailableBalance::NAME, Access::Read);
access.insert(state::StateGetReceipt::NAME, Access::Read);
access.insert(state::StateWaitMsg::NAME, Access::Read);
access.insert(state::StateSearchMsg::NAME, Access::Read);
access.insert(state::StateSearchMsgLimited::NAME, Access::Read);
access.insert(state::StateNetworkVersion::NAME, Access::Read);
access.insert(state::StateAccountKey::NAME, Access::Read);
access.insert(state::StateLookupID::NAME, Access::Read);
access.insert(state::StateFetchRoot::NAME, Access::Read);
access.insert(state::StateGetRandomnessFromTickets::NAME, Access::Read);
access.insert(state::StateGetRandomnessFromBeacon::NAME, Access::Read);
access.insert(state::StateReadState::NAME, Access::Read);
access.insert(state::StateCirculatingSupply::NAME, Access::Read);
access.insert(state::StateSectorGetInfo::NAME, Access::Read);
access.insert(state::StateListMessages::NAME, Access::Read);
access.insert(state::StateListMiners::NAME, Access::Read);
access.insert(state::StateMinerSectorCount::NAME, Access::Read);
access.insert(state::StateMinerSectors::NAME, Access::Read);
access.insert(state::StateMinerPartitions::NAME, Access::Read);
access.insert(state::StateVerifiedClientStatus::NAME, Access::Read);
access.insert(state::StateMarketStorageDeal::NAME, Access::Read);
access.insert(state::StateVMCirculatingSupplyInternal::NAME, Access::Read);
access.insert(state::MsigGetAvailableBalance::NAME, Access::Read);
access.insert(state::MsigGetPending::NAME, Access::Read);
access.insert(state::StateDealProviderCollateralBounds::NAME, Access::Read);
access.insert(state::StateGetBeaconEntry::NAME, Access::Read);
access.insert(state::StateSectorPreCommitInfo::NAME, Access::Read);

// Gas API
access.insert(gas::GasEstimateGasLimit::NAME, Access::Read);
access.insert(gas::GasEstimateGasPremium::NAME, Access::Read);
access.insert(gas::GasEstimateFeeCap::NAME, Access::Read);
access.insert(gas::GasEstimateMessageGas::NAME, Access::Read);

// Common API
access.insert(common::Version::NAME, Access::Read);
access.insert(common::Session::NAME, Access::Read);
access.insert(common::Shutdown::NAME, Access::Admin);
access.insert(common::StartTime::NAME, Access::Read);

// Net API
access.insert(net::NetAddrsListen::NAME, Access::Read);
access.insert(net::NetPeers::NAME, Access::Read);
access.insert(net::NetListening::NAME, Access::Read);
access.insert(net::NetInfo::NAME, Access::Read);
access.insert(net::NetConnect::NAME, Access::Write);
access.insert(net::NetDisconnect::NAME, Access::Write);
access.insert(net::NetAgentVersion::NAME, Access::Read);
access.insert(net::NetAutoNatStatus::NAME, Access::Read);
access.insert(net::NetVersion::NAME, Access::Read);

// Node API
access.insert(node::NodeStatus::NAME, Access::Read);

// Eth API
access.insert(eth::EthAccounts::NAME, Access::Read);
access.insert(eth::EthBlockNumber::NAME, Access::Read);
access.insert(eth::EthChainId::NAME, Access::Read);
access.insert(eth::EthGasPrice::NAME, Access::Read);
access.insert(eth::EthGetBalance::NAME, Access::Read);
access.insert(eth::EthSyncing::NAME, Access::Read);
access.insert(eth::EthGetBlockByNumber::NAME, Access::Read);
access.insert(eth::Web3ClientVersion::NAME, Access::Read);

// Pubsub API
access.insert(CANCEL_METHOD_NAME, Access::Read);
macro_rules! insert {
($ty:ty) => {
access.insert(<$ty>::NAME, <$ty>::PERMISSION);
};
}

auth::for_each_method!(insert);
beacon::for_each_method!(insert);
chain::for_each_method!(insert);
common::for_each_method!(insert);
gas::for_each_method!(insert);
mpool::for_each_method!(insert);
net::for_each_method!(insert);
state::for_each_method!(insert);
node::for_each_method!(insert);
sync::for_each_method!(insert);
wallet::for_each_method!(insert);
eth::for_each_method!(insert);

access.insert(chain::CHAIN_NOTIFY, Permission::Read);
access.insert(CANCEL_METHOD_NAME, Permission::Read);

access
});

/// Checks an access enumeration against provided JWT claims
fn check_access(access: &Access, claims: &[String]) -> bool {
match access {
Access::Admin => claims.contains(&"admin".to_owned()),
Access::Sign => claims.contains(&"sign".to_owned()),
Access::Write => claims.contains(&"write".to_owned()),
Access::Read => claims.contains(&"read".to_owned()),
}
fn is_allowed(required_by_method: Permission, claimed_by_user: &[String]) -> bool {
let needle = match required_by_method {
Permission::Admin => "admin",
Permission::Sign => "sign",
Permission::Write => "write",
Permission::Read => "read",
};
claimed_by_user.iter().any(|haystack| haystack == needle)
}

#[derive(Clone)]
Expand All @@ -194,10 +66,10 @@ pub struct AuthLayer {
}

impl<S> Layer<S> for AuthLayer {
type Service = AuthMiddleware<S>;
type Service = Auth<S>;

fn layer(&self, service: S) -> Self::Service {
AuthMiddleware {
Auth {
headers: self.headers.clone(),
keystore: self.keystore.clone(),
service,
Expand All @@ -206,13 +78,13 @@ impl<S> Layer<S> for AuthLayer {
}

#[derive(Clone)]
pub struct AuthMiddleware<S> {
pub struct Auth<S> {
headers: HeaderMap,
keystore: Arc<RwLock<KeyStore>>,
service: S,
}

impl<'a, S> RpcServiceT<'a> for AuthMiddleware<S>
impl<'a, S> RpcServiceT<'a> for Auth<S>
where
S: RpcServiceT<'a> + Send + Sync + Clone + 'static,
{
Expand Down Expand Up @@ -267,9 +139,9 @@ async fn check_permissions(
};
debug!("Decoded JWT Claims: {}", claims.join(","));

match ACCESS_MAP.get(&method) {
Some(access) => {
if check_access(access, &claims) {
match METHOD_NAME2REQUIRED_PERMISSION.get(&method) {
Some(required_by_method) => {
if is_allowed(*required_by_method, &claims) {
Ok(())
} else {
Err(ErrorCode::InvalidRequest)
Expand Down
4 changes: 3 additions & 1 deletion src/rpc/methods/auth.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

use crate::auth::*;
use crate::lotus_json::lotus_json_with_self;
use crate::rpc::{ApiVersion, Ctx, RpcMethod, ServerError};
use crate::rpc::{ApiVersion, Ctx, Permission, RpcMethod, ServerError};
use anyhow::Result;
use chrono::Duration;
use fvm_ipld_blockstore::Blockstore;
Expand All @@ -25,6 +25,7 @@ impl RpcMethod<1> for AuthNew {
const NAME: &'static str = "Filecoin.AuthNew";
const PARAM_NAMES: [&'static str; 1] = ["params"];
const API_VERSION: ApiVersion = ApiVersion::V0;
const PERMISSION: Permission = Permission::Admin;
type Params = (AuthNewParams,);
type Ok = Vec<u8>;
async fn handle(
Expand All @@ -43,6 +44,7 @@ impl RpcMethod<1> for AuthVerify {
const NAME: &'static str = "Filecoin.AuthVerify";
const PARAM_NAMES: [&'static str; 1] = ["header_raw"];
const API_VERSION: ApiVersion = ApiVersion::V0;
const PERMISSION: Permission = Permission::Read;
type Params = (String,);
type Ok = Vec<String>;
async fn handle(
Expand Down
3 changes: 2 additions & 1 deletion src/rpc/methods/beacon.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
// Copyright 2019-2024 ChainSafe Systems
// SPDX-License-Identifier: Apache-2.0, MIT

use crate::rpc::{ApiVersion, Ctx, RpcMethod, ServerError};
use crate::rpc::{ApiVersion, Ctx, Permission, RpcMethod, ServerError};
use crate::{beacon::BeaconEntry, shim::clock::ChainEpoch};
use anyhow::Result;
use fvm_ipld_blockstore::Blockstore;
Expand All @@ -21,6 +21,7 @@ impl RpcMethod<1> for BeaconGetEntry {
const NAME: &'static str = "Filecoin.BeaconGetEntry";
const PARAM_NAMES: [&'static str; 1] = ["first"];
const API_VERSION: ApiVersion = ApiVersion::V0;
const PERMISSION: Permission = Permission::Read;

type Params = (ChainEpoch,);
type Ok = BeaconEntry;
Expand Down
Loading

0 comments on commit 36d094f

Please sign in to comment.