# Licensed under a BSD-style 3-clause license - see LICENSE.md.
# -*- coding: utf-8 -*-
"""
dlairflow.postgresql
====================
Standard tasks for working with PostgreSQL that can be imported into a DAG.
"""
import os
try:
from airflow.sdk.bases.hook import BaseHook
except ImportError:
from airflow.hooks.base import BaseHook
from .util import ensure_sql
# _legacy_bash = False
try:
from airflow.providers.standard.operators.bash import BashOperator
except ImportError:
from airflow.operators.bash import BashOperator
# _legacy_bash = True
_legacy_postgres = False
try:
from airflow.providers.common.sql.operators.sql import SQLExecuteQueryOperator as PostgresOperator
except ImportError:
from airflow.providers.postgres.operators.postgres import PostgresOperator
_legacy_postgres = True
[docs]
def _connection_to_environment(connection):
    """Translate an Airflow connection into PostgreSQL environment variables.

    Parameters
    ----------
    connection : :class:`str`
        The connection identifier that is handed to
        ``BaseHook.get_connection()`` to look up the connection details.

    Returns
    -------
    :class:`dict`
        A dictionary with the keys ``PGUSER``, ``PGPASSWORD``, ``PGHOST``
        and ``PGDATABASE``, suitable for passing to the ``env`` keyword on,
        *e.g.* :class:`~airflow.providers.standard.operators.bash.BashOperator`.
    """
    details = BaseHook.get_connection(connection)
    # ``schema`` on the connection object supplies the value of PGDATABASE.
    return dict(PGUSER=details.login,
                PGPASSWORD=details.password,
                PGHOST=details.host,
                PGDATABASE=details.schema)
[docs]
def _PostgresOperatorWrapper(**kwargs):
    """Build a ``PostgresOperator`` regardless of the installed Airflow version.

    Parameters
    ----------
    **kwargs
        Keyword arguments forwarded to ``PostgresOperator``. Callers always
        supply the connection as ``conn_id``.

    Returns
    -------
    object
        The operator instance constructed from `kwargs`.

    Raises
    ------
    :exc:`KeyError`
        If the legacy operator is in use and ``conn_id`` was not supplied.
    """
    if not _legacy_postgres:
        return PostgresOperator(**kwargs)
    # The legacy operator names its connection argument ``postgres_conn_id``.
    kwargs['postgres_conn_id'] = kwargs.pop('conn_id')
    return PostgresOperator(**kwargs)
[docs]
def pg_dump_schema(connection, schema, dump_dir):
    """Dump an entire database schema using :command:`pg_dump`.

    The generated shell command skips :command:`pg_dump` when the dump file
    ``<dump_dir>/<schema>.dump`` already exists.

    Parameters
    ----------
    connection : :class:`str`
        An Airflow database connection identifier.
    schema : :class:`str`
        The name of the database schema.
    dump_dir : :class:`str`
        Place the dump file in this directory.

    Returns
    -------
    :class:`~airflow.providers.standard.operators.bash.BashOperator`
        A BashOperator that will execute :command:`pg_dump`.
    """
    environment = _connection_to_environment(connection)
    # Jinja template for the dump file; rendered by the operator from ``params``.
    dump_file = "{{ params.dump_dir }}/{{ params.schema }}.dump"
    command = ("[[ -f " + dump_file + " ]] || "
               "pg_dump --schema={{ params.schema }} --format=c "
               "--file=" + dump_file)
    template_values = {'schema': schema, 'dump_dir': dump_dir}
    return BashOperator(task_id="pg_dump_schema",
                        bash_command=command,
                        params=template_values,
                        env=environment,
                        append_env=True)
[docs]
def pg_restore_schema(connection, schema, dump_dir):
    """Restore a database schema using :command:`pg_restore`.

    :command:`pg_restore` is only invoked when the dump file
    ``<dump_dir>/<schema>.dump`` exists.

    Parameters
    ----------
    connection : :class:`str`
        An Airflow database connection identifier.
    schema : :class:`str`
        The name of the database schema.
    dump_dir : :class:`str`
        Find the dump file in this directory.

    Returns
    -------
    :class:`~airflow.providers.standard.operators.bash.BashOperator`
        A BashOperator that will execute :command:`pg_restore`.

    Notes
    -----
    :command:`pg_restore` only connects to a database when ``--dbname`` is
    given; without it the archive is converted to an SQL script (and recent
    versions refuse to run at all unless ``--dbname`` or ``--file`` is
    specified). The target database is therefore passed explicitly, taken
    from the ``PGDATABASE`` variable that is exported to the task through
    ``env``.
    """
    pg_env = _connection_to_environment(connection)
    # Jinja template for the dump file; rendered by the operator from ``params``.
    dump_file = "{{ params.dump_dir }}/{{ params.schema }}.dump"
    # ``${PGDATABASE}`` is expanded by bash at run time, not by Jinja.
    command = ("[[ -f " + dump_file + " ]] && " +
               'pg_restore --dbname="${PGDATABASE}" ' + dump_file)
    return BashOperator(task_id="pg_restore_schema",
                        bash_command=command,
                        params={'schema': schema,
                                'dump_dir': dump_dir},
                        env=pg_env,
                        append_env=True)
[docs]
def q3c_index(connection, schema, table, ra='ra', dec='dec',
              tablespace=None, overwrite=False):
    """Create a q3c index on `schema`.`table`.

    The SQL template is written into the directory returned by
    ``ensure_sql()`` unless it already exists there.

    Parameters
    ----------
    connection : :class:`str`
        An Airflow database connection identifier.
    schema : :class:`str`
        The name of the database schema.
    table : :class:`str`
        The name of the table in `schema`.
    ra : :class:`str`, optional
        Name of the column containing Right Ascension, default 'ra'.
    dec : :class:`str`, optional
        Name of the column containing Declination, default 'dec'.
    tablespace : :class:`str`, optional
        Create the index in a specific tablespace if set.
    overwrite : :class:`bool`, optional
        If ``True`` replace any existing SQL template file.

    Returns
    -------
    :class:`~airflow.providers.common.sql.operators.sql.SQLExecuteQueryOperator`
        A task to create a q3c index.
    """
    template_name = "dlairflow.postgresql.q3c_index.sql"
    template_path = os.path.join(ensure_sql(), template_name)
    if overwrite or not os.path.exists(template_path):
        # The template creates the index and then clusters the table on it.
        template = """--
-- Created by dlairflow.postgresql.q3c_index().
-- Call q3c_index(..., overwrite=True) to replace this file.
--
CREATE INDEX {{ params.table }}_q3c_ang2ipix
ON {{ params.schema }}.{{ params.table }} (q3c_ang2ipix("{{ params.ra }}", "{{ params.dec }}"))
WITH (fillfactor=100){%- if params.tablespace %} TABLESPACE {{ params.tablespace }}{%- endif -%};
CLUSTER {{ params.table }}_q3c_ang2ipix ON {{ params.schema }}.{{ params.table }};
"""
        with open(template_path, 'w') as handle:
            handle.write(template)
    template_values = {'schema': schema, 'table': table,
                       'ra': ra, 'dec': dec,
                       'tablespace': tablespace}
    return _PostgresOperatorWrapper(sql="sql/" + template_name,
                                    params=template_values,
                                    conn_id=connection,
                                    task_id="q3c_index")
[docs]
def index_columns(connection, schema, table, columns, tablespace=None, overwrite=False):
    """Create "generic" indexes for a set of columns.

    The SQL template is written into the directory returned by
    ``ensure_sql()`` unless it already exists there.

    Parameters
    ----------
    connection : :class:`str`
        An Airflow database connection identifier.
    schema : :class:`str`
        The name of the database schema.
    table : :class:`str`
        The name of the table in `schema`.
    columns : :class:`list`
        A list of columns to index. See below for the possible entries in
        the list of columns.
    tablespace : :class:`str`, optional
        Create the indexes in a specific tablespace if set.
    overwrite : :class:`bool`, optional
        If ``True`` replace any existing SQL template file.

    Returns
    -------
    :class:`~airflow.providers.common.sql.operators.sql.SQLExecuteQueryOperator`
        A task to create several indexes.

    Notes
    -----
    `columns` may be a list containing multiple types:

    * :class:`str`: create an index on one column.
    * :class:`tuple`: create an index on the set of columns in the tuple.
    * :class:`dict`: create a *function* index. The key is the name of the function
      and the value is the column that is the argument to the function.
    * Any other type in `columns` will be ignored.

    A template file written by an earlier version of this function is not
    replaced automatically; pass ``overwrite=True`` to regenerate it.
    """
    sql_dir = ensure_sql()
    sql_basename = "dlairflow.postgresql.index_columns.sql"
    sql_file = os.path.join(sql_dir, sql_basename)
    if overwrite or not os.path.exists(sql_file):
        # The order of the type tests matters: strings and mappings would
        # also satisfy Jinja's ``sequence`` test, so they are checked first.
        sql_data = """--
-- Created by dlairflow.postgresql.index_columns().
-- Call index_columns(..., overwrite=True) to replace this file.
--
{% for col in params.columns %}
{% if col is string -%}
CREATE INDEX {{ params.table }}_{{ col }}_idx
ON {{ params.schema }}.{{ params.table }} ("{{ col }}")
WITH (fillfactor=100){%- if params.tablespace %} TABLESPACE {{ params.tablespace }}{%- endif -%};
{% elif col is mapping -%}
{% for key, value in col.items() -%}
CREATE INDEX {{ params.table }}_{{ key|replace('.', '_') }}_{{ value }}_idx
ON {{ params.schema }}.{{ params.table }} ({{ key }}({{ value }}))
WITH (fillfactor=100){%- if params.tablespace %} TABLESPACE {{ params.tablespace }}{%- endif -%};
{% endfor %}
{% elif col is sequence -%}
CREATE INDEX {{ params.table }}_{{ col|join("_") }}_idx
ON {{ params.schema }}.{{ params.table }} ("{{ col|join('", "') }}")
WITH (fillfactor=100){%- if params.tablespace %} TABLESPACE {{ params.tablespace }}{%- endif -%};
{% else -%}
-- Unknown type: {{ col }}.
{% endif -%}
{% endfor %}
"""
        with open(sql_file, 'w') as s:
            s.write(sql_data)
    return _PostgresOperatorWrapper(sql=f"sql/{sql_basename}",
                                    params={'schema': schema, 'table': table,
                                            'columns': columns,
                                            'tablespace': tablespace},
                                    conn_id=connection,
                                    task_id="index_columns")
[docs]
def primary_key(connection, schema, primary_keys, tablespace=None, overwrite=False):
    """Create a primary key on one or more tables in `schema`.

    The SQL template is written into the directory returned by
    ``ensure_sql()`` unless it already exists there.

    Parameters
    ----------
    connection : :class:`str`
        An Airflow database connection identifier.
    schema : :class:`str`
        The name of the database schema.
    primary_keys : :class:`dict`
        A dictionary mapping the name of each table in `schema` to its
        primary key column(s). See below for details.
    tablespace : :class:`str`, optional
        Create the indexes in a specific tablespace if set.
    overwrite : :class:`bool`, optional
        If ``True`` replace any existing SQL template file.

    Returns
    -------
    :class:`~airflow.providers.common.sql.operators.sql.SQLExecuteQueryOperator`
        A task to create a primary key.

    Notes
    -----
    `primary_keys` may be a :class:`dict` containing multiple types:

    * The key is the table name within `schema`.
    * The value can be:

      - :class:`str`: create a primary key on one column.
      - :class:`tuple`: create a primary key on the set of columns in the tuple.
      - Any other type will be ignored.
    """
    template_name = "dlairflow.postgresql.primary_key.sql"
    template_path = os.path.join(ensure_sql(), template_name)
    if overwrite or not os.path.exists(template_path):
        # The string test comes first because a string would also satisfy
        # Jinja's ``sequence`` test.
        template = """--
-- Created by dlairflow.postgresql.primary_key().
-- Call primary_key(..., overwrite=True) to replace this file.
--
{% for table, columns in params.primary_keys.items() %}
{% if columns is string -%}
ALTER TABLE {{ params.schema }}.{{ table }} ADD PRIMARY KEY ("{{ columns }}")
WITH (fillfactor=100){%- if params.tablespace %} USING INDEX TABLESPACE {{ params.tablespace }}{%- endif -%};
{% elif columns is sequence -%}
ALTER TABLE {{ params.schema }}.{{ table }} ADD PRIMARY KEY ("{{ columns|join('", "') }}")
WITH (fillfactor=100){%- if params.tablespace %} USING INDEX TABLESPACE {{ params.tablespace }}{%- endif -%};
{% else -%}
-- Unknown type: {{ columns }}.
{% endif -%}
{% endfor %}
"""
        with open(template_path, 'w') as handle:
            handle.write(template)
    template_values = {'schema': schema,
                       'primary_keys': primary_keys,
                       'tablespace': tablespace}
    return _PostgresOperatorWrapper(sql="sql/" + template_name,
                                    params=template_values,
                                    conn_id=connection,
                                    task_id="primary_key")
[docs]
def truncate_table(connection, schema, table, restart=False, cascade=False,
                   overwrite=False):
    """Run ``TRUNCATE TABLE`` on one or more tables in `schema`.

    The SQL template is written into the directory returned by
    ``ensure_sql()`` unless it already exists there.

    Parameters
    ----------
    connection : :class:`str`
        An Airflow database connection identifier.
    schema : :class:`str`
        The name of the database schema.
    table : :class:`str` or :class:`list`
        The table(s) to operate on. A :class:`tuple`, :class:`set` or
        :class:`frozenset` is accepted as well.
    restart : :class:`bool`, optional
        If ``True``, any sequences associated with columns in the table(s) will
        be reset. The default is not to reset such sequences.
    cascade : :class:`bool`, optional
        If ``True``, the ``TRUNCATE`` command will also truncate tables connected
        by foreign key relationships. *This is extremely dangerous!*
    overwrite : :class:`bool`, optional
        If ``True``, replace any existing SQL template file.

    Returns
    -------
    :class:`~airflow.providers.common.sql.operators.sql.SQLExecuteQueryOperator`
        A task to run a ``TRUNCATE TABLE`` command.

    Raises
    ------
    :exc:`ValueError`
        If `table` is not a string or list-like object.
    """
    # Validate before touching the filesystem.
    if not isinstance(table, (str, list, tuple, set, frozenset)):
        raise ValueError("Unknown type for table, must be string or list-like!")
    # A single table name is wrapped so the template can always iterate.
    tables = [table] if isinstance(table, str) else table
    template_name = "dlairflow.postgresql.truncate_table.sql"
    template_path = os.path.join(ensure_sql(), template_name)
    if overwrite or not os.path.exists(template_path):
        template = """--
-- Created by dlairflow.postgresql.truncate_table().
-- Call truncate_table(..., overwrite=True) to replace this file.
--
TRUNCATE TABLE {% for table in params.tables -%}
{{ params.schema }}.{{ table }}{{ '' if loop.last else ', ' }}
{%- endfor %}
{% if params.restart -%}RESTART{%- else -%}CONTINUE{%- endif %} IDENTITY
{% if params.cascade -%}CASCADE{%- else -%}RESTRICT{%- endif %};
"""
        with open(template_path, 'w') as handle:
            handle.write(template)
    template_values = {'schema': schema,
                       'tables': tables,
                       'restart': restart,
                       'cascade': cascade}
    return _PostgresOperatorWrapper(sql="sql/" + template_name,
                                    params=template_values,
                                    conn_id=connection,
                                    task_id="truncate_table")
[docs]
def vacuum_analyze(connection, schema, table, full=False, overwrite=False):
    """Run ``VACUUM`` and ``ANALYZE`` on one or more tables in `schema`.

    The SQL template is written into the directory returned by
    ``ensure_sql()`` unless it already exists there.

    Parameters
    ----------
    connection : :class:`str`
        An Airflow database connection identifier.
    schema : :class:`str`
        The name of the database schema.
    table : :class:`str` or :class:`list`
        The table(s) to operate on. A :class:`tuple`, :class:`set` or
        :class:`frozenset` is accepted as well.
    full : :class:`bool`, optional
        If ``True``, run ``VACUUM FULL``.
    overwrite : :class:`bool`, optional
        If ``True`` replace any existing SQL template file.

    Returns
    -------
    :class:`~airflow.providers.common.sql.operators.sql.SQLExecuteQueryOperator`
        A task to run a ``VACUUM`` command.

    Raises
    ------
    :exc:`ValueError`
        If `table` is not a string or list-like object.

    Notes
    -----
    The returned :class:`~airflow.providers.common.sql.operators.sql.SQLExecuteQueryOperator`
    has `autocommit=True` set, which inhibits execution of SQL commands in a
    transaction block. Normally a transaction block is a good thing, but ``VACUUM``
    cannot be run in a transaction block.
    """
    # Validate before touching the filesystem.
    if not isinstance(table, (str, list, tuple, set, frozenset)):
        raise ValueError("Unknown type for table, must be string or list-like!")
    # A single table name is wrapped so the template can always iterate.
    tables = [table] if isinstance(table, str) else table
    template_name = "dlairflow.postgresql.vacuum_analyze.sql"
    template_path = os.path.join(ensure_sql(), template_name)
    if overwrite or not os.path.exists(template_path):
        # One VACUUM statement is emitted per table.
        template = """--
-- Created by dlairflow.postgresql.vacuum_analyze().
-- Call vacuum_analyze(..., overwrite=True) to replace this file.
--
{% for table in params.tables %}
VACUUM {% if params.full -%}FULL{%- endif %} VERBOSE ANALYZE {{ params.schema }}.{{ table }};
{% endfor %}
"""
        with open(template_path, 'w') as handle:
            handle.write(template)
    template_values = {'schema': schema,
                       'tables': tables,
                       'full': full}
    return _PostgresOperatorWrapper(sql="sql/" + template_name,
                                    autocommit=True,
                                    params=template_values,
                                    conn_id=connection,
                                    task_id="vacuum_analyze")