Skip to content

Commit

Permalink
fix(schema-engine): Ensure WS migrations can use shadow database (pri…
Browse files Browse the repository at this point in the history
…sma#5021)

* fix(schema-engine): Ensure WS migrations can use shadow database

Previousy, when creating shadow DB connection over websocket, we
connected to the same DB which broke in every `migrate` case except the
one that starts with clean migration history.

This PR ensures it works normally. Implementation is quite cursed
though: for WS we now allow to override db name via `dbname` query
string parameter. If set, we ignore `dbname` that we got from migration
server and use provided DB with the same username and password. Shadow
DB then uses this query string parameter to specify the url.

Close ORM-325

* Clippy + rustfmt
  • Loading branch information
SevInf authored Oct 17, 2024
1 parent 08713a9 commit b9903dc
Show file tree
Hide file tree
Showing 3 changed files with 46 additions and 6 deletions.
6 changes: 5 additions & 1 deletion quaint/src/connector/postgres/native/websocket.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,12 +23,16 @@ const CONNECTION_PARAMS_HEADER: &str = "Prisma-Connection-Parameters";
const HOST_HEADER: &str = "Prisma-Db-Host";

pub(crate) async fn connect_via_websocket(url: PostgresWebSocketUrl) -> crate::Result<Client> {
let db_name = url.overriden_db_name().map(ToOwned::to_owned);
let (ws_stream, response) = connect_async(url).await?;

let connection_params = require_header_value(response.headers(), CONNECTION_PARAMS_HEADER)?;
let db_host = require_header_value(response.headers(), HOST_HEADER)?;

let config = Config::from_str(connection_params)?;
let mut config = Config::from_str(connection_params)?;
if let Some(db_name) = db_name {
config.dbname(&db_name);
}
let ws_byte_stream = WsStream::new(ws_stream);

let tls = TlsConnector::new(native_tls::TlsConnector::new()?, db_host);
Expand Down
21 changes: 19 additions & 2 deletions quaint/src/connector/postgres/url.rs
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ impl PostgresUrl {
pub fn dbname(&self) -> &str {
match self {
Self::Native(url) => url.dbname(),
Self::WebSocket(_) => "postgres",
Self::WebSocket(url) => url.dbname(),
}
}

Expand Down Expand Up @@ -493,17 +493,34 @@ pub(crate) struct PostgresUrlQueryParams {
pub struct PostgresWebSocketUrl {
pub(crate) url: Url,
pub(crate) api_key: String,
pub(crate) db_name: Option<String>,
}

impl PostgresWebSocketUrl {
pub fn new(url: Url, api_key: String) -> Self {
Self { url, api_key }
Self {
url,
api_key,
db_name: None,
}
}

pub fn override_db_name(&mut self, name: String) {
self.db_name = Some(name)
}

pub fn api_key(&self) -> &str {
&self.api_key
}

pub fn dbname(&self) -> &str {
self.overriden_db_name().unwrap_or("postgres")
}

pub fn overriden_db_name(&self) -> Option<&str> {
self.db_name.as_deref()
}

pub fn host(&self) -> &str {
self.url.host_str().unwrap_or("localhost")
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,11 @@ use crate::SqlFlavour;
use enumflags2::BitFlags;
use indoc::indoc;
use once_cell::sync::Lazy;
use quaint::{connector::PostgresUrl, prelude::NativeConnectionInfo, Value};
use quaint::{
connector::{PostgresUrl, PostgresWebSocketUrl},
prelude::NativeConnectionInfo,
Value,
};
use schema_connector::{
migrations_directory::MigrationDirectory, BoxFuture, ConnectorError, ConnectorParams, ConnectorResult, Namespaces,
};
Expand Down Expand Up @@ -41,6 +45,7 @@ static MIGRATE_WS_BASE_URL: Lazy<Cow<'static, str>> = Lazy::new(|| {
impl MigratePostgresUrl {
const WEBSOCKET_SCHEME: &'static str = "prisma+postgres";
const API_KEY_PARAM: &'static str = "api_key";
const DBNAME_PARAM: &'static str = "dbname";

fn new(url: Url) -> ConnectorResult<Self> {
let postgres_url = if url.scheme() == Self::WEBSOCKET_SCHEME {
Expand All @@ -50,7 +55,14 @@ impl MigratePostgresUrl {
"Required `api_key` query string parameter was not provided in a connection URL",
));
};
PostgresUrl::new_websocket(ws_url, api_key.into_owned())

let dbname_override = url.query_pairs().find(|(name, _)| name == Self::DBNAME_PARAM);
let mut ws_url = PostgresWebSocketUrl::new(ws_url, api_key.into_owned());
if let Some((_, dbname_override)) = dbname_override {
ws_url.override_db_name(dbname_override.into_owned());
}

Ok(PostgresUrl::WebSocket(ws_url))
} else {
PostgresUrl::new_native(url)
}
Expand Down Expand Up @@ -514,7 +526,14 @@ impl SqlFlavour for PostgresFlavour {
.connection_string
.parse()
.map_err(ConnectorError::url_parse_error)?;
shadow_database_url.set_path(&format!("/{shadow_database_name}"));

if shadow_database_url.scheme() == MigratePostgresUrl::WEBSOCKET_SCHEME {
shadow_database_url
.query_pairs_mut()
.append_pair(MigratePostgresUrl::DBNAME_PARAM, &shadow_database_name);
} else {
shadow_database_url.set_path(&format!("/{shadow_database_name}"));
}
let shadow_db_params = ConnectorParams {
connection_string: shadow_database_url.to_string(),
preview_features: params.connector_params.preview_features,
Expand Down

0 comments on commit b9903dc

Please sign in to comment.