# -*- coding: utf-8 -*-
"""
Support for using :mod:`testgres` to create and use a Postgres
instance as a layer.
There is also support for benchmarking and saving databases for
later examination.
.. versionadded:: 4.0.0
The APIs are preliminary and may change.
This is only supported on platforms that can install ``psycopg2``.
"""
from contextlib import contextmanager
import functools
import os
import sys
import unittest
from unittest.mock import patch
#import psycopg2
#import psycopg2.extras
#import psycopg2.pool
try:
from psycopg2 import ProgrammingError
except ImportError:
ThreadedConnectionPool = None
DictCursor = None
class IntegrityError(Exception):
"""Never thrown"""
ProgrammingError = InternalError = IntegrityError
else:
from psycopg2.pool import ThreadedConnectionPool
from psycopg2.extras import DictCursor
from psycopg2 import IntegrityError
from psycopg2 import InternalError
import testgres
if 'PG_CONFIG' not in os.environ:
# Set up for macports and fedora, using files that exist.
# Otherwise, don't set it, assume things are on the path.
for option in (
'/opt/local/lib/postgresql11/bin/pg_config',
'/usr/pgsql-11/bin/pg_config',
):
if os.path.isfile(option):
# TODO: Check exec bit
os.environ['PG_CONFIG'] = option
break
# When true, the database is written to a pg_dump file on teardown
# and the name of that file is printed.
SAVE_DATABASE_ON_TEARDOWN = False
# Optional destination for the dump file; only consulted when
# SAVE_DATABASE_ON_TEARDOWN is true.
SAVE_DATABASE_FILENAME = None
# If this is the path to a database dump file that exists, the
# database will be restored from this file on setUp.
LOAD_DATABASE_ON_SETUP = None

if 'NTI_SAVE_DB' in os.environ:
    # NTI_SAVE_DB is one of (case-insensitively):
    #   0/off/false/no  -> do not save;
    #   1/on/true/yes   -> save to the file chosen by the dump;
    #   anything else   -> save, treating the value as the file name.
    val = os.environ['NTI_SAVE_DB']
    _normalized = val.lower()
    SAVE_DATABASE_ON_TEARDOWN = _normalized not in {'0', 'off', 'false', 'no'}
    if SAVE_DATABASE_ON_TEARDOWN and _normalized not in {'1', 'on', 'true', 'yes'}:
        SAVE_DATABASE_FILENAME = val

if 'NTI_LOAD_DB_FILE' in os.environ:
    LOAD_DATABASE_ON_SETUP = os.environ['NTI_LOAD_DB_FILE']
def patched_get_pg_version(*args, **kwargs):
    """
    Stand-in for ``testgres.node.get_pg_version2`` that never hands
    back a version string ``PgVer`` cannot parse.

    All arguments are forwarded untouched to the real
    ``get_pg_version2`` (its signature changed in testgres 1.10, and
    it replaced ``get_pg_version`` in 1.11, so we accept anything).
    If the reported version is not parseable, a fixed substitute is
    returned and both values are printed.
    """
    # We patch this in testgres.node, so it's OK to import the
    # original from testgres.utils here.
    from testgres.utils import get_pg_version2
    from testgres.node import PgVer
    from packaging.version import InvalidVersion

    # Some installs of postgres return strings that the version
    # parser doesn't like; notably the Python images get a debian
    # build that says "15.3-0+deb12u1", which gets chopped down to
    # "15.3-0+".
    try:
        reported = get_pg_version2(*args, **kwargs)
        # Only called to find out whether the string parses.
        PgVer(reported)
    except InvalidVersion:
        # The actual version string looks like
        # "postgres (PostgreSQL) 15.4", which is processed down to this.
        fallback = "15.4"
        print('testgres: Got invalid postgres version', reported)
        print('testgres: Substituting version', fallback)
        return fallback
    return reported
[docs]
class DatabaseLayer(object):
    """
    A test layer that creates the database, and sets each
    test up in its own connection, aborting the transaction when
    done.
    """

    #: The name of the database within the node. We only create
    #: the default databases, so this should be 'postgres'
    DATABASE_NAME = 'postgres'

    #: A `testgres.node.PostgresNode`, created for the layer.
    #: A psycopg2 connection to it is located in the :attr:`connection`
    #: attribute (similarly for :attr:`connection_pool`), while
    #: a DSN connection string is in :attr:`postgres_dsn`
    postgres_node = None

    #: A string you can use to connect to Postgres.
    postgres_dsn = None

    #: A string you can pass to SQLAlchemy
    postgres_uri = None

    #: Set for each test.
    connection = None

    #: Set for each test.
    cursor = None

    #: The pool every connection is borrowed from; created in
    #: :meth:`setUp` as an instance of :attr:`connection_pool_klass`.
    connection_pool = None
    connection_pool_klass = ThreadedConnectionPool
    connection_pool_minconn = 1
    connection_pool_maxconn = 51

    #: If set to a table name, :meth:`print_size_report` reports on
    #: that table only (and does not hide small tables).
    ONLY_PRINT_SIZE_OF_TABLES = None

    @classmethod
    def setUp(cls):
        """
        Create, configure and start the Postgres node, then open the
        connection pool and compute the DSN/URI strings.
        """
        testgres.configure_testgres()
        # Creating the node needs the tolerant version lookup. The patch
        # is kept active while the node is initialised and started too;
        # NOTE(review): presumably only node creation needs it -- the
        # replacement returns the real value whenever it parses.
        with patch('testgres.node.get_pg_version2', new=patched_get_pg_version):
            node = cls.postgres_node = testgres.get_new_node()
            # init takes about 2 -- 3 seconds
            node.init(
                # Use the encoding as UTF-8. Set the locale as POSIX
                # instead of inheriting it (in JAM's environment, the locale
                # and thus collation and ctype is en_US.UTF8; this turns out to be
                # up to 40% slower than POSIX).
                # We could explicitly specify 'en-x-icu' on each column, if we required
                # ICU support, but it cannot be used as a default collation.
                initdb_params=[
                    "-E", "UTF8",
                    '--locale', 'POSIX',
                    # Don't force to disk; this may save some minor init time.
                    '--no-sync',
                ],
                log_statement='none',
                # Disable unix sockets. Some platforms might try to put this
                # in a directory we can't write to
                unix_sockets=False
            )
            # Speed up bulk inserts
            # These settings appeared to make no difference for the
            # 2 million security insert or the 500K security mapping insert;
            # likely because the final table sizes are < 300MB, so the default max size of
            # 1GB is more than enough.
            node.append_conf('fsync = off')
            node.append_conf('full_page_writes = off')
            node.append_conf('min_wal_size = 500MB')
            node.append_conf('max_wal_size = 2GB')
            # 'replica' is the default. If we use 'minimal' we could be
            # a bit faster, but that's not exactly realistic. Plus,
            # using 'minimal' disables the WAL backup functionality.
            # If we set to 'minimal', we must also set 'max_wal_senders' to 0.
            node.append_conf('wal_level = replica')
            node.append_conf('wal_compression = on')
            node.append_conf('wal_writer_delay = 10000ms')
            node.append_conf('wal_writer_flush_after = 10MB')
            node.append_conf('temp_buffers = 500MB')
            node.append_conf('work_mem = 500MB')
            node.append_conf('maintenance_work_mem = 500MB')
            node.append_conf('shared_buffers = 500MB')
            node.append_conf('max_connections = 100')
            # auto-explain for slow queries
            if 'benchmark' in ' '.join(sys.argv):
                print("Enabling BENCHMARK SETTINGS")
                node.append_conf('shared_preload_libraries = auto_explain')
                node.append_conf('auto_explain.log_min_duration = 40ms')
                node.append_conf('auto_explain.log_nested_statements = on')
                node.append_conf('auto_explain.log_analyze = on')
                node.append_conf('auto_explain.log_timing = on')
                node.append_conf('auto_explain.log_triggers = on')
            # PG 11 only, when --with-llvm was used to compile.
            # It seems if it can't be used, it's ignored? It errors on 10 though,
            # but we only support 11
            node.append_conf('jit = on')
            node.start()

        cls.connection_pool = cls.connection_pool_klass(
            cls.connection_pool_minconn,
            cls.connection_pool_maxconn,
            dbname=cls.DATABASE_NAME,
            host='localhost',
            port=cls.postgres_node.port,
            cursor_factory=DictCursor,
        )
        cls.postgres_dsn = "host=%s dbname=%s port=%s" % (
            node.host, cls.DATABASE_NAME, node.port
        )
        cls.postgres_uri = "postgresql://%s:%s/%s" % (
            cls.postgres_node.host,
            cls.postgres_node.port,
            cls.DATABASE_NAME
        )
        # Report what we are actually talking to.
        with cls.borrowed_connection() as conn:
            with conn.cursor() as cur:
                i = cls.__get_db_info(cur)
                print(f"({i['version']} {i['current_database']}/{i['current_schema']} "
                      f"{i['Encoding']}-{i['Collate']}) ", end="")

    @classmethod
    def tearDown(cls):
        """Close every pooled connection and shut the node down."""
        cls.connection_pool.closeall()
        cls.connection_pool = None
        cls.postgres_node.__exit__(None, None, None)
        cls.postgres_node = None

    @classmethod
    def testSetUp(cls):
        """Borrow a connection and open a cursor for the test about to run."""
        # XXX: Errors here cause the tearDown method to not get called.
        cls.connection = cls.connection_pool.getconn()
        cls.cursor = cls.connection.cursor()

    @classmethod
    def testTearDown(cls):
        """
        Abort whatever the test did, stop listening on all channels,
        and hand the per-test connection back to the pool.
        """
        connection, cursor = cls.connection, cls.cursor
        cls.cursor = None
        cls.connection = None
        try:
            connection.rollback()  # Make sure we're able to execute
            cursor.execute('UNLISTEN *')
            # LISTEN/UNLISTEN only take effect when their transaction
            # commits; without this the UNLISTEN is discarded.
            connection.commit()
            cursor.close()
        finally:
            # Always return the connection, even if the cleanup
            # statements failed, so the pool is not drained.
            cls.connection_pool.putconn(connection)

    @classmethod
    def __get_db_info(cls, cur):
        """
        Use the cursor *cur* to fetch the server version and the
        catalog entry for :attr:`DATABASE_NAME`, returned as a dict.
        """
        query = """
        SELECT version() as version,
        d.datname as "Name",
        pg_catalog.pg_get_userbyid(d.datdba) as "Owner",
        pg_catalog.pg_encoding_to_char(d.encoding) as "Encoding",
        d.datcollate as "Collate",
        d.datctype as "Ctype",
        pg_catalog.array_to_string(d.datacl, E'\n') AS "Access privileges",
        current_database() as "current_database",
        current_schema() as "current_schema"
        FROM pg_catalog.pg_database d
        WHERE d.datname = %s
        """
        cur.execute(query, (cls.DATABASE_NAME,))
        row = cur.fetchone()
        return dict(row)

    @classmethod
    @contextmanager
    def borrowed_connection(cls):
        """
        Context manager that returns a connection from the connection
        pool.

        The connection is put back in the pool when the block exits,
        whether or not it raised.
        """
        conn = cls.connection_pool.getconn()
        try:
            yield conn
        finally:
            cls.connection_pool.putconn(conn)

    @classmethod
    def truncate_table(cls, conn, table_name):
        """
        Transactionally truncate the given *table_name* using *conn*.

        A ``ProgrammingError`` or ``InternalError`` is printed and the
        transaction rolled back instead of being propagated.
        """
        try:
            with conn.cursor() as cur:
                cur.execute(
                    'TRUNCATE TABLE ' + table_name + ' CASCADE'
                )
        except (ProgrammingError, InternalError):
            # Table doesn't exist, not a full schema,
            # ignore.
            # OR:
            # Already aborted
            import traceback
            traceback.print_exc()
            conn.rollback()
        else:
            # Is PostgreSQL, TRUNCATE is transactional!
            # Awesome!
            conn.commit()

    @classmethod
    def drop_relation(cls, relation, kind='TABLE', idempotent=False):
        """
        Drops the *relation* of type *kind* (default table), in new transaction.

        If *idempotent* is true, ``IF EXISTS`` is added to the statement.
        """
        if_exists = ' IF EXISTS' if idempotent else ''
        with cls.borrowed_connection() as conn:
            with conn.cursor() as cur:
                cur.execute(f"DROP {kind}{if_exists} {relation}")
            conn.commit()

    @classmethod
    def vacuum(cls, *tables, **kwargs):
        """
        Run ``VACUUM (FREEZE, ANALYZE)`` on each of *tables*, or on
        the whole database if none are given.

        Keyword arguments:

        verbose
            If true (default false), add ``VERBOSE`` and print the
            notices the server sends back.
        size_report
            If true (the default), call :meth:`print_size_report`
            afterwards.
        """
        verbose = ''
        if kwargs.pop('verbose', False):
            verbose = ', VERBOSE'
        # FULL rewrites all tables and takes forever.
        # FREEZE is simpler and compacts tables
        stmt = f'VACUUM (FREEZE, ANALYZE {verbose}) '
        tables = tables or ('',)

        with cls.borrowed_connection() as conn:
            # VACUUM cannot run inside a transaction block...
            conn.autocommit = True
            try:
                with conn.cursor() as cur:
                    for t in tables:
                        cur.execute(stmt + t)
            finally:
                # ...but the connection is shared through the pool, so
                # never hand it back still in autocommit mode.
                conn.autocommit = False
            if verbose:
                for n in conn.notices:
                    print(n)
                del conn.notices[:]

        if kwargs.pop('size_report', True):
            cls.print_size_report()

    @classmethod
    def print_size_report(cls):
        """
        Print a table of total/index/toast/table sizes, largest first.

        Relations named ``pg_*`` or ``abstract_*`` and those in the
        system schemas are left out. Unless
        :attr:`ONLY_PRINT_SIZE_OF_TABLES` is set, rows whose total is
        one of a handful of small sizes are left out too.
        """
        extra_query = ''
        if cls.ONLY_PRINT_SIZE_OF_TABLES:
            t = cls.ONLY_PRINT_SIZE_OF_TABLES
            extra_query = f"AND table_name = '{t}'"
        query = f"""
        SELECT table_name,
        pg_size_pretty(total_bytes) AS total,
        pg_size_pretty(index_bytes) AS INDEX,
        pg_size_pretty(toast_bytes) AS toast,
        pg_size_pretty(table_bytes) AS TABLE
        FROM (
        SELECT *, total_bytes-index_bytes-COALESCE(toast_bytes,0) AS table_bytes FROM (
        SELECT c.oid,nspname AS table_schema, relname AS TABLE_NAME
        , c.reltuples AS row_estimate
        , pg_total_relation_size(c.oid) AS total_bytes
        , pg_indexes_size(c.oid) AS index_bytes
        , pg_total_relation_size(reltoastrelid) AS toast_bytes
        FROM pg_class c
        LEFT JOIN pg_namespace n ON n.oid = c.relnamespace
        WHERE relkind = 'r'
        ) a
        ) a
        WHERE (table_name NOT LIKE 'pg_%' and table_name not like 'abstract_%'
        {extra_query}
        )
        AND table_schema <> 'pg_catalog' and table_schema <> 'information_schema'
        ORDER BY total_bytes DESC
        """
        with cls.borrowed_connection() as conn:
            with conn.cursor() as cur:
                cur.execute(query)
                rows = [dict(row) for row in cur]

        # The first row printed is the header: each column's own name.
        keys = ['table_name', 'total', 'index', 'toast', 'table']
        rows.insert(0, {k: k for k in keys})
        print()
        fmt = "| {table_name:35s} | {total:10s} | {index:10s} | {toast:10s} | {table:10s}"
        for row in rows:
            if not extra_query and row['total'] in {
                    '72 kB', '32 kB', '24 kB', '16 kB', '8192 bytes'
            }:
                continue
            print(fmt.format(
                **{k: v if v else '<null>' for k, v in row.items()}
            ))
[docs]
class SchemaDatabaseLayer(DatabaseLayer):
    """
    A test layer that adds our schema.
    """

    #: Absolute path of the schema file loaded by :meth:`setUp`,
    #: found five directories above the one containing this module.
    SCHEMA_FILE = os.path.abspath(os.path.join(
        os.path.dirname(os.path.abspath(__file__)),
        '..', '..', '..', '..', '..',
        'full_schema.sql'
    ))

    @classmethod
    def run_files(cls, *files):
        """
        Run each of *files* through ``psql`` on the layer's node,
        stopping at the first one that fails.

        Passing no files does nothing.

        :raises subprocess.CalledProcessError: If ``psql`` reports a
            non-zero exit code; its decoded output is printed first
            and attached to the exception.
        """
        for fname in files:
            code, stdout, stderr = cls.postgres_node.psql(
                filename=fname,
                ON_ERROR_STOP=1
            )
            if code:
                import subprocess
                stdout = stdout.decode("utf-8")
                stderr = stderr.decode('utf-8')
                print(stdout)
                print(stderr)
                raise subprocess.CalledProcessError(
                    code,
                    'psql',
                    stdout,
                    stderr
                )

    @classmethod
    def _tangle_schema_if_needed(cls):
        """
        Regenerate SQL files that are missing or older than the org
        file they come from, working in the current directory.

        Exits the process with status 1 if emacs cannot be run,
        fails, or reports that it tangled nothing.
        """
        # If the schema files do not exist, or db.org is newer
        # than they are, run emacs to weave the files together.
        # This requires a working emacs with org-mode available.
        from pathlib import Path
        import subprocess
        cwd = Path(".")
        # Each org file tangles to at least one sql file.
        org_to_sql = {
            org: org.with_suffix('.sql')
            for org in cwd.glob("*.org")
        }
        org_to_sql[Path("db.org")] = Path("full_schema.sql")

        for org, sql in org_to_sql.items():
            if not org.exists():
                continue
            if sql.exists() and sql.stat().st_mtime >= org.stat().st_mtime:
                continue

            print(f"\nDatabase schema files outdated; tangling {org}")
            ex = None
            try:
                output = subprocess.check_output([
                    "emacs",
                    "--batch",
                    "--eval",
                    f'''(progn
                    (package-initialize)
                    (require 'org)
                    (org-babel-tangle-file "{org}")
                    )'''
                ], stderr=subprocess.STDOUT)
            except FileNotFoundError as e:
                output = str(e).encode('utf-8')
                ex = e
            except subprocess.CalledProcessError as e:
                # The captured output lives on the exception just
                # caught; ``ex`` has not been assigned at this point.
                output = e.output
                ex = e  # pylint:disable=redefined-variable-type

            output = output.decode('utf-8')
            if ex is not None or 'Tangled 0' in output:
                print("Failed to tangle database schema; "
                      "(check file paths):\n",
                      output,
                      file=sys.stderr)
                sys.exit(1)

    @classmethod
    def setUp(cls):
        """
        Load the schema into the database: from the schema file's
        directory, re-tangle if needed, then run ``prereq.sql`` (when
        present) followed by :attr:`SCHEMA_FILE`. The working
        directory is restored afterwards.
        """
        cwd = os.getcwd()
        try:
            os.chdir(os.path.dirname(cls.SCHEMA_FILE))
            # XXX: Do exceptions here prevent the super tearDown()
            # from being called?
            cls._tangle_schema_if_needed()
            to_run = [cls.SCHEMA_FILE]
            if os.path.exists("prereq.sql"):
                to_run.insert(0, "prereq.sql")
            cls.run_files(*to_run)
        finally:
            os.chdir(cwd)

    @classmethod
    def tearDown(cls):
        """Do nothing; overrides the inherited ``tearDown``."""

    @classmethod
    def testSetUp(cls):
        """Do nothing; overrides the inherited ``testSetUp``."""

    @classmethod
    def testTearDown(cls):
        """Do nothing; overrides the inherited ``testTearDown``."""
[docs]
class DatabaseBackupLayerHelper:
    """
    A layer helper that works with another layer to

    * create a backup of the current database on `push`;
    * make that backup active;
    * switch the connection pool to that backup
    * reverse all of that on layer `pop`

    Note that this consists of modifying values in the `DatabaseLayer`,
    so the *layer* parameter must extend that.
    """

    # Stacks of the nodes and pools that were active before each push.
    # They are deliberately shared class state.
    _nodes = []
    _pools = []

    @classmethod
    def push(cls, layer):
        """
        Back up the active node, start a copy of it, and make that
        copy (with a new connection pool) the active one.

        Nothing is recorded or replaced until the copy and its pool
        both exist, so a failure leaves the current node and pool
        active and the stacks untouched.
        """
        current_node = DatabaseLayer.postgres_node
        current_pool = DatabaseLayer.connection_pool

        with layer.borrowed_connection() as conn:
            with conn.cursor() as cur:
                # If we don't checkpoint here, then the backup waits
                # for the next WAL checkpoint to happen. We may not have
                # written much to the WAL, so we could wait until a time limit
                # expires, which is often 30+ seconds. We don't want to wait.
                cur.execute('CHECKPOINT')

        # NOTE(review): spawning a node presumably looks up the server
        # version like creating one does, so use the same tolerant
        # lookup that DatabaseLayer.setUp uses -- confirm against testgres.
        with patch('testgres.node.get_pg_version2', new=patched_get_pg_version):
            # A streaming backup uses a replication slot, but it
            # does the copy in parallel.
            backup = current_node.backup(xlog_method='stream')
            new_node = backup.spawn_primary()

        try:
            new_node.start()
            new_pool = layer.connection_pool_klass(
                layer.connection_pool_minconn,
                layer.connection_pool_maxconn,
                dbname=layer.DATABASE_NAME,
                host='localhost',
                port=new_node.port,
                cursor_factory=DictCursor
            )
        except BaseException:
            # Don't leave a half-started copy behind; shut it down the
            # same way DatabaseLayer.tearDown does.
            new_node.__exit__(None, None, None)
            raise

        cls._nodes.append(current_node)
        cls._pools.append(current_pool)
        DatabaseLayer.postgres_node = new_node
        DatabaseLayer.connection_pool = new_pool

    @classmethod
    def pop(cls, layer):  # pylint:disable=unused-argument
        """
        Tear down the active node and pool, and reactivate the ones
        that were active before the matching `push`.
        """
        DatabaseLayer.tearDown()  # Closes the current node, and the connection pool
        DatabaseLayer.postgres_node = cls._nodes.pop()
        DatabaseLayer.connection_pool = cls._pools.pop()
# The base class for the persistent layer. A database file we load
# already has the schema info, so in that case the schema layer is
# not needed.
_persistent_base = (
    DatabaseLayer
    if LOAD_DATABASE_ON_SETUP
    else SchemaDatabaseLayer
)
[docs]
class PersistentDatabaseLayer(_persistent_base):
    """
    A layer that establishes persistent data visible to
    all of its tests (and all of its sub-layers).

    Sub-layers need to check whether they should
    clean up or not, because we may be saving the database file.

    It's important to have a fairly linear layer
    setup, or layers that don't interfere with each other.
    """

    @classmethod
    def setUp(cls):
        """Restore the database from the configured dump file, if any."""
        if not LOAD_DATABASE_ON_SETUP:
            return
        print(f" (Loading database from {LOAD_DATABASE_ON_SETUP}) ",
              end='',
              flush=True)
        cls.postgres_node.restore(LOAD_DATABASE_ON_SETUP)
        cls.vacuum()

    @classmethod
    def testSetUp(cls):
        """Do nothing for individual tests."""

    @classmethod
    def testTearDown(cls):
        """Do nothing for individual tests."""

    @classmethod
    def persistent_layer_skip_teardown(cls):
        """
        Should persistent layers, that write data intended to be
        visible between tests (and in sub-layers) tear down that data
        when the layer is torn down? If we're saving the database, we
        don't want to do that.

        Raising NotImplementedError causes the testrunner to assume
        it's python resources that are the problem and continue in a new
        subprocess, which doesn't help (and may hurt?). So you must check this as a
        boolean.
        """
        return SAVE_DATABASE_ON_TEARDOWN

    @classmethod
    def persistent_layer_skip_setup(cls):
        """
        Should persistent layers skip their setup because
        we loaded a save file?

        The value is the configured file name (or ``None``), so
        check it as a boolean.
        """
        return LOAD_DATABASE_ON_SETUP

    @classmethod
    def tearDown(cls):
        """
        If saving is enabled, dump the database and print where the
        dump ended up.

        When a file name was configured the dump is moved there,
        with ``.1`` appended as often as needed to avoid an existing
        file.
        """
        if not SAVE_DATABASE_ON_TEARDOWN:
            return

        dumped_to = cls.postgres_node.dump(format='custom')
        destination = dumped_to
        if SAVE_DATABASE_FILENAME:
            import shutil
            destination = SAVE_DATABASE_FILENAME
            # Never overwrite an earlier dump.
            while os.path.exists(destination):
                destination += '.1'
            shutil.move(dumped_to, destination)
        print(f" (Database dumped to {destination}) ", end='')
def persistent_skip_setup(func):
    """
    Decorate a layer function taking only *cls* so that it is a
    no-op whenever ``cls.persistent_layer_skip_setup()`` is true.

    The wrapper always returns ``None``.
    """
    @functools.wraps(func)
    def maybe_skip_setup(cls):
        if not cls.persistent_layer_skip_setup():
            func(cls)
    return maybe_skip_setup
def persistent_skip_teardown(func):
    """
    Decorate a layer function taking only *cls* so that it is a
    no-op whenever ``cls.persistent_layer_skip_teardown()`` is true.

    The wrapper always returns ``None``.
    """
    @functools.wraps(func)
    def maybe_skip_teardown(cls):
        if not cls.persistent_layer_skip_teardown():
            func(cls)
    return maybe_skip_teardown
[docs]
class DatabaseTestCase(unittest.TestCase):
    """
    A helper test base containing some functions useful for both
    benchmarking and unit testing.

    The methods read the connection and cursor from ``self.layer``.
    """
    # pylint:disable=no-member

    @contextmanager
    def assertRaisesIntegrityError(self, match=None):
        """
        Context manager asserting that the block raises
        ``IntegrityError`` (matching the regex *match*, if given),
        then rolling back the layer's connection.
        """
        if match:
            expectation = self.assertRaisesRegex(IntegrityError, match)
        else:
            expectation = self.assertRaises(IntegrityError)

        with expectation as exc:
            yield exc
        # We can't do any queries after an error is raised
        # until we rollback.
        self.layer.connection.rollback()

    def assert_row_count_in_query(self, expected_count, query):
        """
        Assert that ``SELECT COUNT(*) FROM <query>`` yields
        *expected_count*, using the layer's cursor.
        """
        cursor = self.layer.cursor
        cursor.execute('SELECT COUNT(*) FROM ' + query)
        actual_count = cursor.fetchone()[0]
        self.assertEqual(expected_count, actual_count, query)

    def assert_row_count_in_table(self, expected_count, table_name):
        """Assert that *table_name* holds exactly *expected_count* rows."""
        # Kept as a local for traceback reporting.
        __traceback_info__ = table_name
        self.assert_row_count_in_query(expected_count, table_name)

    def assert_row_count_in_cursor(self, rowcount, cursor=None):
        """
        Assert that ``rowcount`` of *cursor* (by default the layer's
        cursor) equals *rowcount*.
        """
        if cursor is None:
            cursor = self.layer.cursor
        self.assertEqual(cursor.rowcount, rowcount)