Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions syncserver-db-common/src/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,16 @@ pub struct SqlError {
pub backtrace: Backtrace,
}

impl SqlError {
#[cfg(debug_assertions)]
pub fn is_diesel_not_found(&self) -> bool {
matches!(
self.kind,
SqlErrorKind::DieselQuery(diesel::result::Error::NotFound)
)
}
}

#[derive(Debug, Error)]
enum SqlErrorKind {
#[error("A database error occurred: {}", _0)]
Expand Down
5 changes: 5 additions & 0 deletions tokenserver-db-common/src/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,11 @@ impl DbError {
pub fn pool_timeout(timeout_type: deadpool::managed::TimeoutType) -> Self {
DbErrorKind::PoolTimeout(timeout_type).into()
}

#[cfg(debug_assertions)]
pub fn is_diesel_not_found(&self) -> bool {
matches!(&self.kind, DbErrorKind::Sql(e) if e.is_diesel_not_found())
}
}

impl ReportableError for DbError {
Expand Down
88 changes: 40 additions & 48 deletions tokenserver-db/src/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,9 @@ async fn test_update_generation() -> DbResult<()> {
let pool = db_pool().await?;
let mut db = pool.get().await?;

// Add a service
let service_id = db
.post_service(params::PostService {
.get_service_id(params::GetServiceId {
service: "sync-1.5".to_owned(),
pattern: "{node}/1.5/{uid}".to_owned(),
})
.await?
.id;
Expand Down Expand Up @@ -88,11 +86,9 @@ async fn test_update_keys_changed_at() -> DbResult<()> {
let pool = db_pool().await?;
let mut db = pool.get().await?;

// Add a service
let service_id = db
.post_service(params::PostService {
.get_service_id(params::GetServiceId {
service: "sync-1.5".to_owned(),
pattern: "{node}/1.5/{uid}".to_owned(),
})
.await?
.id;
Expand Down Expand Up @@ -170,11 +166,9 @@ async fn replace_users() -> DbResult<()> {
.as_millis() as i64;
let an_hour_ago = now - MILLISECONDS_IN_AN_HOUR;

// Add a service
let service_id = db
.post_service(params::PostService {
.get_service_id(params::GetServiceId {
service: "sync-1.5".to_owned(),
pattern: "{node}/1.5/{uid}".to_owned(),
})
.await?
.id;
Expand Down Expand Up @@ -346,11 +340,9 @@ async fn post_user() -> DbResult<()> {
let pool = db_pool().await?;
let mut db = pool.get().await?;

// Add a service
let service_id = db
.post_service(params::PostService {
.get_service_id(params::GetServiceId {
service: "sync-1.5".to_owned(),
pattern: "{node}/1.5/{uid}".to_owned(),
})
.await?
.id;
Expand Down Expand Up @@ -412,11 +404,9 @@ async fn get_node_id() -> DbResult<()> {
let pool = db_pool().await?;
let mut db = pool.get().await?;

// Add a service
let service_id = db
.post_service(params::PostService {
.get_service_id(params::GetServiceId {
service: "sync-1.5".to_owned(),
pattern: "{node}/1.5/{uid}".to_owned(),
})
.await?
.id;
Expand Down Expand Up @@ -459,11 +449,9 @@ async fn test_node_allocation() -> DbResult<()> {
let pool = db_pool().await?;
let mut db = pool.get().await?;

// Add a service
let service_id = db
.post_service(params::PostService {
.get_service_id(params::GetServiceId {
service: "sync-1.5".to_owned(),
pattern: "{node}/1.5/{uid}".to_owned(),
})
.await?
.id;
Expand Down Expand Up @@ -506,11 +494,9 @@ async fn test_allocation_to_least_loaded_node() -> DbResult<()> {
let pool = db_pool().await?;
let mut db = pool.get().await?;

// Add a service
let service_id = db
.post_service(params::PostService {
.get_service_id(params::GetServiceId {
service: "sync-1.5".to_owned(),
pattern: "{node}/1.5/{uid}".to_owned(),
})
.await?
.id;
Expand Down Expand Up @@ -571,11 +557,9 @@ async fn test_allocation_is_not_allowed_to_downed_nodes() -> DbResult<()> {
let pool = db_pool().await?;
let mut db = pool.get().await?;

// Add a service
let service_id = db
.post_service(params::PostService {
.get_service_id(params::GetServiceId {
service: "sync-1.5".to_owned(),
pattern: "{node}/1.5/{uid}".to_owned(),
})
.await?
.id;
Expand Down Expand Up @@ -614,11 +598,9 @@ async fn test_allocation_is_not_allowed_to_backoff_nodes() -> DbResult<()> {
let pool = db_pool().await?;
let mut db = pool.get().await?;

// Add a service
let service_id = db
.post_service(params::PostService {
.get_service_id(params::GetServiceId {
service: "sync-1.5".to_owned(),
pattern: "{node}/1.5/{uid}".to_owned(),
})
.await?
.id;
Expand Down Expand Up @@ -657,11 +639,9 @@ async fn test_node_reassignment_when_records_are_replaced() -> DbResult<()> {
let pool = db_pool().await?;
let mut db = pool.get().await?;

// Add a service
let service_id = db
.post_service(params::PostService {
.get_service_id(params::GetServiceId {
service: "sync-1.5".to_owned(),
pattern: "{node}/1.5/{uid}".to_owned(),
})
.await?
.id;
Expand Down Expand Up @@ -731,11 +711,9 @@ async fn test_node_reassignment_not_done_for_retired_users() -> DbResult<()> {
let pool = db_pool().await?;
let mut db = pool.get().await?;

// Add a service
let service_id = db
.post_service(params::PostService {
.get_service_id(params::GetServiceId {
service: "sync-1.5".to_owned(),
pattern: "{node}/1.5/{uid}".to_owned(),
})
.await?
.id;
Expand Down Expand Up @@ -787,11 +765,9 @@ async fn test_node_reassignment_and_removal() -> DbResult<()> {
let pool = db_pool().await?;
let mut db = pool.get().await?;

// Add a service
let service_id = db
.post_service(params::PostService {
.get_service_id(params::GetServiceId {
service: "sync-1.5".to_owned(),
pattern: "{node}/1.5/{uid}".to_owned(),
})
.await?
.id;
Expand Down Expand Up @@ -938,11 +914,9 @@ async fn test_gradual_release_of_node_capacity() -> DbResult<()> {
let pool = db_pool().await?;
let mut db = pool.get().await?;

// Add a service
let service_id = db
.post_service(params::PostService {
.get_service_id(params::GetServiceId {
service: "sync-1.5".to_owned(),
pattern: "{node}/1.5/{uid}".to_owned(),
})
.await?
.id;
Expand Down Expand Up @@ -1104,11 +1078,9 @@ async fn test_correct_created_at_used_during_node_reassignment() -> DbResult<()>
let pool = db_pool().await?;
let mut db = pool.get().await?;

// Add a service
let service_id = db
.post_service(params::PostService {
.get_service_id(params::GetServiceId {
service: "sync-1.5".to_owned(),
pattern: "{node}/1.5/{uid}".to_owned(),
})
.await?
.id;
Expand Down Expand Up @@ -1168,11 +1140,9 @@ async fn test_correct_created_at_used_during_user_retrieval() -> DbResult<()> {
let pool = db_pool().await?;
let mut db = pool.get().await?;

// Add a service
let service_id = db
.post_service(params::PostService {
.get_service_id(params::GetServiceId {
service: "sync-1.5".to_owned(),
pattern: "{node}/1.5/{uid}".to_owned(),
})
.await?
.id;
Expand Down Expand Up @@ -1227,11 +1197,9 @@ async fn test_get_spanner_node() -> DbResult<()> {
let pool = db_pool().await?;
let mut db = pool.get().await?;

// Add a service
let service_id = db
.post_service(params::PostService {
.get_service_id(params::GetServiceId {
service: "sync-1.5".to_owned(),
pattern: "{node}/1.5/{uid}".to_owned(),
})
.await?
.id;
Expand Down Expand Up @@ -1309,5 +1277,29 @@ async fn db_pool() -> DbResult<Box<dyn DbPool>> {
use_test_transactions,
)?;
pool.init().await?;

if settings.tokenserver.database_url.starts_with("mysql://") {
// Ensure the "sync-1.5" service
// TODO: tokenserver-mysql's migration should add this service
// entry for us (if possible)
let mut db = pool.get().await?;
let service = "sync-1.5".to_owned();
let result = db
.get_service_id(params::GetServiceId {
service: service.clone(),
})
.await;
if let Err(e) = result {
if !e.is_diesel_not_found() {
return Err(e);
}
db.post_service(params::PostService {
service,
pattern: "{node}/1.5/{uid}".to_owned(),
})
.await?;
}
}

Ok(pool)
}
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
-- Create Tables
CREATE TABLE IF NOT EXISTS services (
id SERIAL PRIMARY KEY,
id INTEGER GENERATED BY DEFAULT AS IDENTITY PRIMARY KEY,
service VARCHAR(30) UNIQUE,
pattern VARCHAR(128)
);

CREATE TABLE IF NOT EXISTS nodes (
id BIGSERIAL PRIMARY KEY,
id BIGINT GENERATED BY DEFAULT AS IDENTITY PRIMARY KEY,
service INTEGER NOT NULL,
node VARCHAR(64) NOT NULL,
available INTEGER NOT NULL,
Expand All @@ -18,7 +18,7 @@ CREATE TABLE IF NOT EXISTS nodes (
);

CREATE TABLE IF NOT EXISTS users (
uid BIGSERIAL PRIMARY KEY,
uid BIGINT GENERATED BY DEFAULT AS IDENTITY PRIMARY KEY,
service INTEGER NOT NULL,
email VARCHAR(255) NOT NULL,
generation BIGINT NOT NULL,
Expand All @@ -34,4 +34,9 @@ CREATE INDEX IF NOT EXISTS lookup_idx ON users (email, service, created_at);

CREATE INDEX IF NOT EXISTS replaced_at_idx ON users (service, replaced_at);

CREATE INDEX IF NOT EXISTS node_idx ON users (nodeid);
CREATE INDEX IF NOT EXISTS node_idx ON users (nodeid);


-- The standard Sync service entry
INSERT INTO services (service, pattern) VALUES
('sync-1.5', '{node}/1.5/{uid}');
35 changes: 18 additions & 17 deletions tools/integration_tests/tokenserver/test_support.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,32 +50,27 @@ def setUp(self):
)

# Start each test with a blank slate.
cursor = self._execute_sql(sqltext(("DELETE FROM users")), {})
cursor.close()
self._clear_db()

cursor = self._execute_sql((sqltext("DELETE FROM nodes")), {})
cursor.close()

cursor = self._execute_sql(sqltext(("DELETE FROM services")), {})
cursor.close()

self.service_id = self._add_service("sync-1.5", r"{node}/1.5/{uid}")
# TODO: tokenserver-mysql's migration should add this
# service entry for us (if possible)
self.service_id = self._get_or_add_service("sync-1.5", r"{node}/1.5/{uid}")

# Ensure we have a node with enough capacity to run the tests.
self._add_node(capacity=100, node=self.NODE_URL, id=self.NODE_ID)

def tearDown(self):
# And clean up at the end, for good measure.
self._clear_db()
self.database.close()

def _clear_db(self):
cursor = self._execute_sql(sqltext(("DELETE FROM users")), {})
cursor.close()

cursor = self._execute_sql(sqltext(("DELETE FROM nodes")), {})
cursor.close()

cursor = self._execute_sql(sqltext(("DELETE FROM services")), {})
cursor.close()

self.database.close()
# NOTE: don't clear the services between tests as tokenserver
# may have already cached its "sync-1.5" service_id

def _build_oauth_headers(
self,
Expand Down Expand Up @@ -337,10 +332,16 @@ def _get_replaced_users(self, service, email):
def _get_service_id(self, service):
query = sqltext("select id from services where service = :service")
cursor = self._execute_sql(query, {"service": service})
(service_id,) = cursor.fetchone()
row = cursor.fetchone()
cursor.close()

return service_id
return None if row is None else row[0]

def _get_or_add_service(self, service, pattern):
service_id = self._get_service_id(service)
if service_id is not None:
return service_id
return self._add_service(service, pattern)

def _count_users(self):
query = sqltext("select COUNT(DISTINCT(uid)) from users")
Expand Down