Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1,048 changes: 485 additions & 563 deletions src-tauri/Cargo.lock

Large diffs are not rendered by default.

4 changes: 4 additions & 0 deletions src-tauri/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,10 @@ russh = "0.48"
russh-keys = "0.48"
async-trait = "0.1"
tokio = { version = "1", features = ["sync", "net", "io-util"] }
tiberius = { version = "0.12", default-features = false, features = ["tokio", "chrono", "tds73", "native-tls"] }
async-native-tls = "0.5"
tokio-util = { version = "0.7", features = ["compat"] }
base64 = "0.22"
tauri-plugin-os = "2"
tauri-plugin-clipboard-manager = "2.3.2"
arboard = "3.6.1"
Expand Down
7 changes: 7 additions & 0 deletions src-tauri/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,10 @@ use tauri::menu::{AboutMetadata, Menu, MenuItemBuilder, PredefinedMenuItem, Subm
use tauri::{Emitter, Manager};
use tauri_plugin_updater::UpdaterExt;

mod mssql;
mod ssh_tunnel;

use mssql::MssqlConnectionManager;
use ssh_tunnel::TunnelManager;

struct PendingUpdate {
Expand Down Expand Up @@ -157,6 +159,7 @@ pub fn run() {
tauri::Builder::default()
.plugin(tauri_plugin_os::init())
.manage(TunnelManager::new())
.manage(MssqlConnectionManager::new())
.manage(PendingUpdate { bytes: Mutex::new(None) })
.plugin(tauri_plugin_updater::Builder::new().build())
.plugin(tauri_plugin_process::init())
Expand All @@ -173,6 +176,10 @@ pub fn run() {
ssh_tunnel::close_ssh_tunnel,
ssh_tunnel::check_tunnel_status,
ssh_tunnel::list_active_tunnels,
mssql::mssql_connect,
mssql::mssql_disconnect,
mssql::mssql_query,
mssql::mssql_execute,
])
.setup(|app| {
// Set up custom menu
Expand Down
270 changes: 270 additions & 0 deletions src-tauri/src/mssql.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,270 @@
use async_native_tls::TlsStream;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::sync::Arc;
use tauri::State;
use tiberius::{AuthMethod, Client, Config, Query, Row};
use tokio::net::TcpStream;
use tokio::sync::Mutex;
use tokio_util::compat::{Compat, TokioAsyncReadCompatExt};

#[derive(Debug, Serialize, Deserialize)]
pub struct MssqlConfig {
pub host: String,
pub port: u16,
pub database: String,
pub username: String,
pub password: String,
pub encrypt: Option<bool>,
pub trust_cert: Option<bool>,
}

#[derive(Debug, Serialize)]
pub struct MssqlConnection {
pub connection_id: String,
}

#[derive(Debug, Serialize)]
pub struct MssqlQueryResult {
pub columns: Vec<String>,
pub rows: Vec<serde_json::Value>,
pub rows_affected: u64,
}

#[derive(Debug, Serialize, Deserialize)]
pub struct MssqlError {
pub message: String,
pub code: String,
}

impl std::fmt::Display for MssqlError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "{}: {}", self.code, self.message)
}
}

impl std::error::Error for MssqlError {}

type MssqlClient = Client<TlsStream<Compat<TcpStream>>>;

struct ConnectionHandle {
client: MssqlClient,
}

pub struct MssqlConnectionManager {
connections: Arc<Mutex<HashMap<String, ConnectionHandle>>>,
next_id: Arc<Mutex<u64>>,
}

impl MssqlConnectionManager {
pub fn new() -> Self {
Self {
connections: Arc::new(Mutex::new(HashMap::new())),
next_id: Arc::new(Mutex::new(1)),
}
}
}

impl Default for MssqlConnectionManager {
fn default() -> Self {
Self::new()
}
}

fn row_to_json(row: &Row) -> serde_json::Value {
let mut obj = serde_json::Map::new();
for col in row.columns() {
let col_name = col.name().to_string();
// Try to get value as different types, falling back through common types
// Start with string since SQL Server often returns nvarchar
let value = if let Some(v) = row.try_get::<&str, _>(col_name.as_str()).ok().flatten() {
serde_json::json!(v)
} else if let Some(v) = row.try_get::<i64, _>(col_name.as_str()).ok().flatten() {
serde_json::json!(v)
} else if let Some(v) = row.try_get::<i32, _>(col_name.as_str()).ok().flatten() {
serde_json::json!(v)
} else if let Some(v) = row.try_get::<i16, _>(col_name.as_str()).ok().flatten() {
serde_json::json!(v)
} else if let Some(v) = row.try_get::<u8, _>(col_name.as_str()).ok().flatten() {
serde_json::json!(v)
} else if let Some(v) = row.try_get::<f64, _>(col_name.as_str()).ok().flatten() {
serde_json::json!(v)
} else if let Some(v) = row.try_get::<f32, _>(col_name.as_str()).ok().flatten() {
serde_json::json!(v)
} else if let Some(v) = row.try_get::<bool, _>(col_name.as_str()).ok().flatten() {
serde_json::json!(v)
} else if let Some(v) = row.try_get::<&[u8], _>(col_name.as_str()).ok().flatten() {
// Binary data - encode as base64
use base64::{Engine as _, engine::general_purpose::STANDARD};
serde_json::json!(STANDARD.encode(v))
} else {
// NULL or unsupported type (dates, decimals, GUIDs handled as NULL for now)
// These would require additional feature flags in tiberius
serde_json::Value::Null
};
obj.insert(col_name, value);
}
serde_json::Value::Object(obj)
}

#[tauri::command]
pub async fn mssql_connect(
config: MssqlConfig,
manager: State<'_, MssqlConnectionManager>,
) -> Result<MssqlConnection, MssqlError> {
let mut tiberius_config = Config::new();

tiberius_config.host(&config.host);
tiberius_config.port(config.port);
tiberius_config.database(&config.database);
tiberius_config.authentication(AuthMethod::sql_server(&config.username, &config.password));

// We handle TLS manually, so tell tiberius not to do encryption
tiberius_config.encryption(tiberius::EncryptionLevel::NotSupported);

// Connect with timeout
let tcp = tokio::time::timeout(
std::time::Duration::from_secs(30),
TcpStream::connect(tiberius_config.get_addr()),
)
.await
.map_err(|_| MssqlError {
message: "Connection timed out".to_string(),
code: "TIMEOUT".to_string(),
})?
.map_err(|e| MssqlError {
message: format!("Failed to connect: {}", e),
code: "CONNECTION_ERROR".to_string(),
})?;

tcp.set_nodelay(true).map_err(|e| MssqlError {
message: format!("Failed to set TCP nodelay: {}", e),
code: "TCP_ERROR".to_string(),
})?;

// Wrap TCP stream with compat for futures-io trait compatibility
let tcp_compat = tcp.compat();

// Wrap with TLS - Azure SQL requires encryption
let tls_connector = async_native_tls::TlsConnector::new()
.danger_accept_invalid_certs(config.trust_cert.unwrap_or(true))
.use_sni(true);

let tls_stream = tls_connector
.connect(&config.host, tcp_compat)
.await
.map_err(|e| MssqlError {
message: format!("TLS connection failed: {}", e),
code: "TLS_ERROR".to_string(),
})?;

// TlsStream already implements futures-io traits, pass directly to tiberius
let client = Client::connect(tiberius_config, tls_stream)
.await
.map_err(|e| MssqlError {
message: format!("Failed to connect to SQL Server: {}", e),
code: "AUTH_ERROR".to_string(),
})?;

// Generate connection ID
let connection_id = {
let mut next_id = manager.next_id.lock().await;
let id = format!("mssql-{}", *next_id);
*next_id += 1;
id
};

// Store connection
{
let mut connections = manager.connections.lock().await;
connections.insert(connection_id.clone(), ConnectionHandle { client });
}

Ok(MssqlConnection { connection_id })
}

#[tauri::command]
pub async fn mssql_disconnect(
connection_id: String,
manager: State<'_, MssqlConnectionManager>,
) -> Result<(), MssqlError> {
let mut connections = manager.connections.lock().await;

if connections.remove(&connection_id).is_some() {
Ok(())
} else {
Err(MssqlError {
message: format!("Connection not found: {}", connection_id),
code: "CONNECTION_NOT_FOUND".to_string(),
})
}
}

#[tauri::command]
pub async fn mssql_query(
connection_id: String,
sql: String,
manager: State<'_, MssqlConnectionManager>,
) -> Result<MssqlQueryResult, MssqlError> {
let mut connections = manager.connections.lock().await;

let handle = connections.get_mut(&connection_id).ok_or(MssqlError {
message: format!("Connection not found: {}", connection_id),
code: "CONNECTION_NOT_FOUND".to_string(),
})?;

let query = Query::new(&sql);
let stream = query.query(&mut handle.client).await.map_err(|e| MssqlError {
message: format!("Query failed: {}", e),
code: "QUERY_ERROR".to_string(),
})?;

let rows = stream.into_first_result().await.map_err(|e| MssqlError {
message: format!("Failed to fetch results: {}", e),
code: "FETCH_ERROR".to_string(),
})?;

// Get column names from first row or return empty result
let columns: Vec<String> = if !rows.is_empty() {
rows[0].columns().iter().map(|c| c.name().to_string()).collect()
} else {
vec![]
};

// Convert rows to JSON
let json_rows: Vec<serde_json::Value> = rows
.iter()
.map(|row| row_to_json(row))
.collect();

Ok(MssqlQueryResult {
columns,
rows: json_rows,
rows_affected: 0,
})
}

#[tauri::command]
pub async fn mssql_execute(
connection_id: String,
sql: String,
manager: State<'_, MssqlConnectionManager>,
) -> Result<MssqlQueryResult, MssqlError> {
let mut connections = manager.connections.lock().await;

let handle = connections.get_mut(&connection_id).ok_or(MssqlError {
message: format!("Connection not found: {}", connection_id),
code: "CONNECTION_NOT_FOUND".to_string(),
})?;

let result = handle.client.execute(&sql, &[]).await.map_err(|e| MssqlError {
message: format!("Execute failed: {}", e),
code: "EXECUTE_ERROR".to_string(),
})?;

Ok(MssqlQueryResult {
columns: vec![],
rows: vec![],
rows_affected: result.rows_affected().iter().sum(),
})
}
12 changes: 6 additions & 6 deletions src/lib/components/app-header.svelte
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@

const handleConnectionClick = (connection: typeof db.state.connections[0]) => {
// If connection has a database instance, just activate it
if (connection.database) {
if (connection.database || connection.mssqlConnectionId) {
db.connections.setActive(connection.id);
} else {
// If no database (persisted connection), open dialog with prefilled values
Expand Down Expand Up @@ -95,7 +95,7 @@
<span
class={[
"size-2 rounded-full shrink-0",
db.state.activeConnection.database ? "bg-green-500" : "bg-gray-400"
(db.state.activeConnection.database || db.state.activeConnection.mssqlConnectionId) ? "bg-green-500" : "bg-gray-400"
]}
></span>
<span class="max-w-32 truncate" title={db.state.activeConnection.name}>{db.state.activeConnection.name}</span>
Expand All @@ -117,9 +117,9 @@
<span
class={[
"size-2 rounded-full shrink-0",
connection.database ? "bg-green-500" : "bg-gray-400"
(connection.database || connection.mssqlConnectionId) ? "bg-green-500" : "bg-gray-400"
]}
title={connection.database ? m.header_connected() : m.header_disconnected()}
title={(connection.database || connection.mssqlConnectionId) ? m.header_connected() : m.header_disconnected()}
></span>
<span class="flex-1 truncate">{connection.name}</span>
{#if db.state.activeConnectionId === connection.id}
Expand All @@ -128,7 +128,7 @@
</DropdownMenu.Item>
</ContextMenu.Trigger>
<ContextMenu.Content class="w-40">
{#if connection.database}
{#if connection.database || connection.mssqlConnectionId}
<ContextMenu.Item onclick={() => db.connections.toggle(connection.id)}>
{m.header_disconnect()}
</ContextMenu.Item>
Expand Down Expand Up @@ -170,7 +170,7 @@
{/if}
</div>
<div class="flex items-center gap-1">
{#if db.state.activeConnection?.database}
{#if db.state.activeConnection?.database || db.state.activeConnection?.mssqlConnectionId}
<Button
size="icon"
variant="ghost"
Expand Down
6 changes: 3 additions & 3 deletions src/lib/components/command-palette.svelte
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@
const openTabs = $derived(db.tabs.ordered);
const activeResult = $derived(db.state.activeQueryResult);
const hasResults = $derived((activeResult?.rows?.length ?? 0) > 0);
const isConnected = $derived(!!db.state.activeConnectionId && !!db.state.activeConnection?.database);
const isConnected = $derived(!!db.state.activeConnectionId && !!(db.state.activeConnection?.database || db.state.activeConnection?.mssqlConnectionId));
const hasActiveQueryTab = $derived(isConnected && !!db.state.activeQueryTab);
const hasQueryContent = $derived(hasActiveQueryTab && !!db.state.activeQueryTab?.query?.trim());
const hasConnections = $derived(connections.length > 0);
Expand Down Expand Up @@ -137,7 +137,7 @@
const connection = connections.find((c) => c.id === id);
if (!connection) return;

if (connection.database) {
if (connection.database || connection.mssqlConnectionId) {
// Already connected, just switch to it
runAndClose(() => db.connections.setActive(id));
} else {
Expand Down Expand Up @@ -370,7 +370,7 @@
</span>
{#if connection.id === db.state.activeConnectionId}
<span class="text-muted-foreground ms-auto text-xs">{m.command_status_active()}</span>
{:else if !connection.database}
{:else if !(connection.database || connection.mssqlConnectionId)}
<span class="text-muted-foreground ms-auto text-xs">{m.command_status_disconnected()}</span>
{/if}
</Command.Item>
Expand Down
Loading
Loading