Skip to content

Commit

Permalink
feat: network CRUD
Browse files Browse the repository at this point in the history
  • Loading branch information
JettChenT committed Jul 22, 2024
1 parent edd796e commit 8d1f0a9
Show file tree
Hide file tree
Showing 6 changed files with 155 additions and 19 deletions.
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -7,3 +7,5 @@ target/

.DS_STORE
.idea/

tmp/
Binary file modified burrow.db
Binary file not shown.
5 changes: 3 additions & 2 deletions burrow/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ tokio = { version = "1.37", features = [
"signal",
"time",
"tracing",
"fs",
] }
tun = { version = "0.1", path = "../tun", features = ["serde", "tokio"] }
clap = { version = "4.4", features = ["derive"] }
Expand Down Expand Up @@ -56,7 +57,7 @@ reqwest = { version = "0.12", default-features = false, features = [
"json",
"rustls-tls",
] }
rusqlite = "0.31.0"
rusqlite = { version = "0.31.0", features = ["blob"] }
dotenv = "0.15.0"
tonic = "0.12.0"
prost = "0.13.1"
Expand All @@ -73,7 +74,7 @@ tracing-journald = "0.3"

[target.'cfg(target_vendor = "apple")'.dependencies]
nix = { version = "0.27" }
rusqlite = { version = "0.31.0", features = ["bundled"] }
rusqlite = { version = "0.31.0", features = ["bundled", "blob"] }

[dev-dependencies]
insta = { version = "1.32", features = ["yaml"] }
Expand Down
39 changes: 33 additions & 6 deletions burrow/src/daemon/instance.rs
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,14 @@ use crate::{
ServerConfig,
ServerInfo,
},
database::{delete_network, get_connection, list_networks, load_interface, reorder_network},
database::{
add_network,
delete_network,
get_connection,
list_networks,
load_interface,
reorder_network,
},
wireguard::{Config, Interface},
};

Expand Down Expand Up @@ -164,6 +171,7 @@ pub struct DaemonRPCServer {
config: Arc<RwLock<Config>>,
db_path: Option<PathBuf>,
wg_state_chan: (watch::Sender<RunState>, watch::Receiver<RunState>),
network_update_chan: (watch::Sender<()>, watch::Receiver<()>),
}

impl DaemonRPCServer {
Expand All @@ -178,6 +186,7 @@ impl DaemonRPCServer {
config,
db_path: db_path.map(|p| p.to_owned()),
wg_state_chan: watch::channel(RunState::Idle),
network_update_chan: watch::channel(()),
})
}

Expand All @@ -192,6 +201,10 @@ impl DaemonRPCServer {
async fn get_wg_state(&self) -> RunState {
self.wg_state_chan.1.borrow().to_owned()
}

async fn notify_network_update(&self) -> Result<(), RspStatus> {
self.network_update_chan.0.send(()).map_err(proc_err)
}
}

#[tonic::async_trait]
Expand Down Expand Up @@ -267,7 +280,11 @@ impl Tunnel for DaemonRPCServer {
loop {
state_rx.changed().await.unwrap();
let cur = state_rx.borrow().to_owned();
tx.send(Ok(status_rsp(cur))).await;
let res = tx.send(Ok(status_rsp(cur))).await;
if res.is_err() {
eprintln!("Tunnel status channel closed");
break;
}
}
});
Ok(Response::new(ReceiverStream::new(rx)))
Expand All @@ -278,8 +295,11 @@ impl Tunnel for DaemonRPCServer {
impl Networks for DaemonRPCServer {
type NetworkListStream = ReceiverStream<Result<NetworkListResponse, RspStatus>>;

async fn network_add(&self, _request: Request<Network>) -> Result<Response<Empty>, RspStatus> {
debug!("Mock network_add called");
async fn network_add(&self, request: Request<Network>) -> Result<Response<Empty>, RspStatus> {
let conn = self.get_connection()?;
let network = request.into_inner();
add_network(&conn, &network).map_err(proc_err)?;
self.notify_network_update().await?;
Ok(Response::new(Empty {}))
}

Expand All @@ -290,13 +310,18 @@ impl Networks for DaemonRPCServer {
debug!("Mock network_list called");
let (tx, rx) = mpsc::channel(10);
let conn = self.get_connection()?;
let mut sub = self.network_update_chan.1.clone();
tokio::spawn(async move {
loop {
let networks = list_networks(&conn)
.map(|res| NetworkListResponse { network: res })
.map_err(proc_err);
tx.send(networks).await.unwrap();
tokio::time::sleep(Duration::from_secs(10)).await;
let res = tx.send(networks).await;
if res.is_err() {
eprintln!("Network list channel closed");
break;
}
sub.changed().await.unwrap();
}
});
Ok(Response::new(ReceiverStream::new(rx)))
Expand All @@ -308,6 +333,7 @@ impl Networks for DaemonRPCServer {
) -> Result<Response<Empty>, RspStatus> {
let conn = self.get_connection()?;
reorder_network(&conn, request.into_inner()).map_err(proc_err)?;
self.notify_network_update().await?;
Ok(Response::new(Empty {}))
}

Expand All @@ -317,6 +343,7 @@ impl Networks for DaemonRPCServer {
) -> Result<Response<Empty>, RspStatus> {
let conn = self.get_connection()?;
delete_network(&conn, request.into_inner()).map_err(proc_err)?;
self.notify_network_update().await?;
Ok(Response::new(Empty {}))
}
}
Expand Down
31 changes: 22 additions & 9 deletions burrow/src/database.rs
Original file line number Diff line number Diff line change
Expand Up @@ -40,10 +40,18 @@ const CREATE_WG_PEER_TABLE: &str = "CREATE TABLE IF NOT EXISTS wg_peer (
const CREATE_NETWORK_TABLE: &str = "CREATE TABLE IF NOT EXISTS network (
id INTEGER PRIMARY KEY AUTOINCREMENT,
type TEXT NOT NULL,
raw_payload TEXT,
index INTEGER AUTOINCREMENT,
interface_id INT REFERENCES wg_interface(id) ON UPDATE CASCADE,
)";
payload BLOB,
idx INTEGER,
interface_id INT REFERENCES wg_interface(id) ON UPDATE CASCADE
);
CREATE TRIGGER IF NOT EXISTS increment_network_idx
AFTER INSERT ON network
BEGIN
UPDATE network
SET idx = (SELECT COALESCE(MAX(idx), 0) + 1 FROM network)
WHERE id = NEW.id;
END;
";

pub fn initialize_tables(conn: &Connection) -> Result<()> {
conn.execute(CREATE_WG_INTERFACE_TABLE, [])?;
Expand Down Expand Up @@ -126,21 +134,26 @@ pub fn get_connection(path: Option<&Path>) -> Result<Connection> {
}

pub fn add_network(conn: &Connection, network: &RPCNetwork) -> Result<()> {
let mut stmt = conn.prepare("INSERT INTO network (type, payload) VALUES (?, ?)")?;
stmt.execute(params![network.r#type().as_str_name(), &network.payload])?;
let mut stmt = conn.prepare("INSERT INTO network (id, type, payload) VALUES (?, ?, ?)")?;
stmt.execute(params![
network.id,
network.r#type().as_str_name(),
&network.payload
])?;
// TODO: if the type is Wireguard, add the corresponding neetwork interface and then link it
Ok(())
}

pub fn list_networks(conn: &Connection) -> Result<Vec<RPCNetwork>> {
let mut stmt = conn.prepare("SELECT id, type, payload FROM network ORDER BY id")?;
let mut stmt = conn.prepare("SELECT id, type, payload FROM network ORDER BY idx")?;
let networks: Vec<RPCNetwork> = stmt
.query_map([], |row| {
println!("row: {:?}", row);
let network_id: i32 = row.get(0)?;
let network_type: String = row.get(1)?;
let network_type = NetworkType::from_str_name(network_type.as_str())
.ok_or(rusqlite::Error::InvalidQuery)?;
let payload: String = row.get(2)?;
let payload: Vec<u8> = row.get(2)?;
Ok(RPCNetwork {
id: network_id,
r#type: network_type.into(),
Expand All @@ -152,7 +165,7 @@ pub fn list_networks(conn: &Connection) -> Result<Vec<RPCNetwork>> {
}

pub fn reorder_network(conn: &Connection, req: NetworkReorderRequest) -> Result<()> {
let mut stmt = conn.prepare("UPDATE network SET index = ? WHERE id = ?")?;
let mut stmt = conn.prepare("UPDATE network SET idx = ? WHERE id = ?")?;
let res = stmt.execute(params![req.index, req.id])?;
if res == 0 {
return Err(anyhow::anyhow!("No such network exists"));
Expand Down
97 changes: 95 additions & 2 deletions burrow/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,14 @@ enum Commands {
ServerStatus,
/// Tunnel Config
TunnelConfig,
/// Add Network
NetworkAdd(NetworkAddArgs),
/// List Networks
NetworkList,
/// Reorder Network
NetworkReorder(NetworkReorderArgs),
/// Delete Network
NetworkDelete(NetworkDeleteArgs),
}

#[derive(Args)]
Expand All @@ -72,6 +80,24 @@ struct StartArgs {}
#[derive(Args)]
struct DaemonArgs {}

#[derive(Args)]
struct NetworkAddArgs {
id: i32,
network_type: i32,
payload_path: String,
}

#[derive(Args)]
struct NetworkReorderArgs {
id: i32,
index: i32,
}

#[derive(Args)]
struct NetworkDeleteArgs {
id: i32,
}

#[cfg(any(target_os = "linux", target_vendor = "apple"))]
async fn try_start() -> Result<()> {
let mut client = BurrowClient::from_uds().await?;
Expand Down Expand Up @@ -120,6 +146,69 @@ async fn try_tun_config() -> Result<()> {
Ok(())
}

#[cfg(any(target_os = "linux", target_vendor = "apple"))]
async fn try_network_add(id: i32, network_type: i32, payload_path: &str) -> Result<()> {
use tokio::{fs::File, io::AsyncReadExt};

use crate::daemon::rpc::grpc_defs::Network;

let mut file = File::open(payload_path).await?;
let mut payload = Vec::new();
file.read_to_end(&mut payload).await?;

let mut client = BurrowClient::from_uds().await?;
let network = Network {
id,
r#type: network_type,
payload,
};
let res = client.networks_client.network_add(network).await?;
println!("Network Add Response: {:?}", res);
Ok(())
}

#[cfg(any(target_os = "linux", target_vendor = "apple"))]
async fn try_network_list() -> Result<()> {
let mut client = BurrowClient::from_uds().await?;
let mut res = client
.networks_client
.network_list(Empty {})
.await?
.into_inner();
while let Some(network_list) = res.message().await? {
println!("Network List: {:?}", network_list);
}
Ok(())
}

#[cfg(any(target_os = "linux", target_vendor = "apple"))]
async fn try_network_reorder(id: i32, index: i32) -> Result<()> {
use crate::daemon::rpc::grpc_defs::NetworkReorderRequest;

let mut client = BurrowClient::from_uds().await?;
let reorder_request = NetworkReorderRequest { id, index };
let res = client
.networks_client
.network_reorder(reorder_request)
.await?;
println!("Network Reorder Response: {:?}", res);
Ok(())
}

#[cfg(any(target_os = "linux", target_vendor = "apple"))]
async fn try_network_delete(id: i32) -> Result<()> {
use crate::daemon::rpc::grpc_defs::NetworkDeleteRequest;

let mut client = BurrowClient::from_uds().await?;
let delete_request = NetworkDeleteRequest { id };
let res = client
.networks_client
.network_delete(delete_request)
.await?;
println!("Network Delete Response: {:?}", res);
Ok(())
}

#[cfg(any(target_os = "linux", target_vendor = "apple"))]
fn handle_unexpected(res: Result<DaemonResponseData, String>) {
match res {
Expand Down Expand Up @@ -176,8 +265,6 @@ async fn try_reloadconfig(interface_id: String) -> Result<()> {
#[cfg(any(target_os = "linux", target_vendor = "apple"))]
#[tokio::main]
async fn main() -> Result<()> {
use daemon::get_socket_path;

tracing::initialize();
dotenv::dotenv().ok();

Expand All @@ -192,6 +279,12 @@ async fn main() -> Result<()> {
Commands::AuthServer => crate::auth::server::serve().await?,
Commands::ServerStatus => try_serverstatus().await?,
Commands::TunnelConfig => try_tun_config().await?,
Commands::NetworkAdd(args) => {
try_network_add(args.id, args.network_type, &args.payload_path).await?
}
Commands::NetworkList => try_network_list().await?,
Commands::NetworkReorder(args) => try_network_reorder(args.id, args.index).await?,
Commands::NetworkDelete(args) => try_network_delete(args.id).await?,
}

Ok(())
Expand Down

0 comments on commit 8d1f0a9

Please sign in to comment.