Skip to content
Snippets Groups Projects
Commit e97d1cf0 authored by Richard van der Hoff's avatar Richard van der Hoff
Browse files

Modify check_database to take a connection rather than a cursor

We might not need the cursor at all.
parent d2906fe6
No related branches found
No related tags found
No related merge requests found
...@@ -447,15 +447,6 @@ class Porter(object): ...@@ -447,15 +447,6 @@ class Porter(object):
else: else:
return return
def setup_db(self, db_config: DatabaseConnectionConfig, engine):
db_conn = make_conn(db_config, engine)
prepare_database(db_conn, engine, config=None)
db_conn.commit()
return db_conn
@defer.inlineCallbacks
def build_db_store(self, db_config: DatabaseConnectionConfig): def build_db_store(self, db_config: DatabaseConnectionConfig):
"""Builds and returns a database store using the provided configuration. """Builds and returns a database store using the provided configuration.
...@@ -468,16 +459,14 @@ class Porter(object): ...@@ -468,16 +459,14 @@ class Porter(object):
self.progress.set_state("Preparing %s" % db_config.config["name"]) self.progress.set_state("Preparing %s" % db_config.config["name"])
engine = create_engine(db_config.config) engine = create_engine(db_config.config)
conn = self.setup_db(db_config, engine)
hs = MockHomeserver(self.hs_config) hs = MockHomeserver(self.hs_config)
store = Store(Database(hs, db_config, engine), conn, hs) with make_conn(db_config, engine) as db_conn:
engine.check_database(db_conn)
yield store.db.runInteraction( prepare_database(db_conn, engine, config=None)
"%s_engine.check_database" % db_config.config["name"], store = Store(Database(hs, db_config, engine), db_conn, hs)
engine.check_database, db_conn.commit()
)
return store return store
...@@ -502,7 +491,7 @@ class Porter(object): ...@@ -502,7 +491,7 @@ class Porter(object):
@defer.inlineCallbacks @defer.inlineCallbacks
def run(self): def run(self):
try: try:
self.sqlite_store = yield self.build_db_store( self.sqlite_store = self.build_db_store(
DatabaseConnectionConfig("master-sqlite", self.sqlite_config) DatabaseConnectionConfig("master-sqlite", self.sqlite_config)
) )
...@@ -518,7 +507,7 @@ class Porter(object): ...@@ -518,7 +507,7 @@ class Porter(object):
) )
defer.returnValue(None) defer.returnValue(None)
self.postgres_store = yield self.build_db_store( self.postgres_store = self.build_db_store(
self.hs_config.get_single_database() self.hs_config.get_single_database()
) )
......
...@@ -47,7 +47,7 @@ class DataStores(object): ...@@ -47,7 +47,7 @@ class DataStores(object):
with make_conn(database_config, engine) as db_conn: with make_conn(database_config, engine) as db_conn:
logger.info("Preparing database %r...", db_name) logger.info("Preparing database %r...", db_name)
engine.check_database(db_conn.cursor()) engine.check_database(db_conn)
prepare_database( prepare_database(
db_conn, engine, hs.config, data_stores=database_config.data_stores, db_conn, engine, hs.config, data_stores=database_config.data_stores,
) )
......
...@@ -32,14 +32,15 @@ class PostgresEngine(object): ...@@ -32,14 +32,15 @@ class PostgresEngine(object):
self.synchronous_commit = database_config.get("synchronous_commit", True) self.synchronous_commit = database_config.get("synchronous_commit", True)
self._version = None # unknown as yet self._version = None # unknown as yet
def check_database(self, txn): def check_database(self, db_conn):
txn.execute("SHOW SERVER_ENCODING") with db_conn.cursor() as txn:
rows = txn.fetchall() txn.execute("SHOW SERVER_ENCODING")
if rows and rows[0][0] != "UTF8": rows = txn.fetchall()
raise IncorrectDatabaseSetup( if rows and rows[0][0] != "UTF8":
"Database has incorrect encoding: '%s' instead of 'UTF8'\n" raise IncorrectDatabaseSetup(
"See docs/postgres.rst for more information." % (rows[0][0],) "Database has incorrect encoding: '%s' instead of 'UTF8'\n"
) "See docs/postgres.rst for more information." % (rows[0][0],)
)
def convert_param_style(self, sql): def convert_param_style(self, sql):
return sql.replace("?", "%s") return sql.replace("?", "%s")
......
...@@ -53,7 +53,7 @@ class Sqlite3Engine(object): ...@@ -53,7 +53,7 @@ class Sqlite3Engine(object):
""" """
return False return False
def check_database(self, txn): def check_database(self, db_conn):
pass pass
def convert_param_style(self, sql): def convert_param_style(self, sql):
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment