"""Provides a hook to interact with a dbt project."""
from __future__ import annotations
import json
import logging
import sys
from contextlib import contextmanager
from pathlib import Path
from tempfile import TemporaryDirectory
from typing import (
TYPE_CHECKING,
Any,
Dict,
Iterable,
Iterator,
NamedTuple,
Optional,
Tuple,
Union,
)
from urllib.parse import urlparse
from airflow.exceptions import AirflowException
from airflow.hooks.base import BaseHook
from airflow.models.connection import Connection
if sys.version_info >= (3, 11):
from contextlib import chdir as chdir_ctx
else:
from contextlib_chdir import chdir as chdir_ctx
if TYPE_CHECKING:
from dbt.contracts.results import RunResult
from dbt.task.base import BaseTask
from airflow_dbt_python.hooks.remote import DbtRemoteHook
from airflow_dbt_python.utils.configs import BaseConfig
from airflow_dbt_python.utils.url import URLLike
DbtRemoteHooksDict = Dict[Tuple[str, Optional[str]], DbtRemoteHook]
class DbtTaskResult(NamedTuple):
"""A tuple returned after a dbt task executes.

    Attributes:
success: Whether the task succeeded or not.
run_results: Results from the dbt task, if available.
artifacts: A dictionary of saved dbt artifacts. It may be empty.
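
    Example:
        A minimal, illustrative construction:

        >>> result = DbtTaskResult(success=True, run_results=None, artifacts={})
        >>> result.success
        True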
"""
success: bool
run_results: Optional[RunResult]
artifacts: dict[str, Any]
class DbtConnectionParam(NamedTuple):
"""A tuple indicating connection parameters relevant to dbt.

    Attributes:
name: The name of the connection parameter. This name will be used to get the
parameter from an Airflow Connection or its extras.
        store_override_name: A new name for the connection parameter. If not None,
            this is the name used in a dbt profiles.yml.
default: A default value if the parameter is not found.
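
    Example:
        A parameter with a fallback value for when it's missing from a Connection:

        >>> DbtConnectionParam("port", default=5432).default
        5432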
"""
name: str
store_override_name: Optional[str] = None
default: Optional[Any] = None
@property
def override_name(self):
"""Returns the override_name if defined, otherwise defaults to name.
>>> DbtConnectionParam("login", "user").override_name
'user'
>>> DbtConnectionParam("port").override_name
'port'
"""
if self.store_override_name is None:
return self.name
return self.store_override_name
class DbtTemporaryDirectory(TemporaryDirectory):
"""A wrapper on TemporaryDirectory for older versions of Python.
Support for ignore_cleanup_errors was added in Python 3.10. There is a very obscure
error that can happen when cleaning up a directory, even though everything should
be cleaned. We would like to use ignore_cleanup_errors to provide clean up on a
best-effort basis. For the time being, we are addressing this only for Python>=3.10.
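
    Example:
        Used like a regular TemporaryDirectory:

        >>> with DbtTemporaryDirectory(prefix="example_") as tmp_dir:
        ...     Path(tmp_dir).exists()
        True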
"""
def __init__(self, suffix=None, prefix=None, dir=None, ignore_cleanup_errors=True):
        if sys.version_info < (3, 10):
super().__init__(suffix=suffix, prefix=prefix, dir=dir)
else:
super().__init__(
suffix=suffix,
prefix=prefix,
dir=dir,
ignore_cleanup_errors=ignore_cleanup_errors,
)
class DbtHook(BaseHook):
"""A hook to interact with dbt.

    Allows for running dbt tasks and provides required configurations for each task.
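
    Example:
        A minimal usage sketch; the command and project path are illustrative:

        >>> hook = DbtHook()  # doctest: +SKIP
        >>> result = hook.run_dbt_task(  # doctest: +SKIP
        ...     "run", project_dir="/path/to/project"
        ... )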
"""
conn_name_attr = "dbt_conn_id"
default_conn_name = "dbt_default"
conn_type = "dbt"
hook_name = "dbt Hook"
conn_params: list[Union[DbtConnectionParam, str]] = [
DbtConnectionParam("conn_type", "type"),
"host",
DbtConnectionParam("conn_id", "dbname"),
"schema",
DbtConnectionParam("login", "user"),
"password",
"port",
]
conn_extra_params: list[Union[DbtConnectionParam, str]] = []
def __init__(
self,
*args,
dbt_conn_id: Optional[str] = default_conn_name,
project_conn_id: Optional[str] = None,
profiles_conn_id: Optional[str] = None,
**kwargs,
):
self.remotes: DbtRemoteHooksDict = {}
self.dbt_conn_id = dbt_conn_id
self.project_conn_id = project_conn_id
self.profiles_conn_id = profiles_conn_id
super().__init__(*args, **kwargs)
def get_remote(self, scheme: str, conn_id: Optional[str]) -> DbtRemoteHook:
"""Get a remote to interact with dbt files.

        RemoteHooks are defined by the scheme we are looking for and an optional
connection id if we are looking to interface with any Airflow hook that
uses a connection.
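
        For example, an S3 remote backed by an Airflow AWS connection
        (the connection id is illustrative):

        >>> remote = hook.get_remote("s3", "aws_default")  # doctest: +SKIP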
"""
from .remote import get_remote
try:
return self.remotes[(scheme, conn_id)]
except KeyError:
remote = get_remote(scheme, conn_id)
self.remotes[(scheme, conn_id)] = remote
return remote
def download_dbt_profiles(
self,
profiles_dir: URLLike,
destination: URLLike,
) -> Path:
"""Pull a dbt profiles.yml file from a given profiles_dir.

        This operation is delegated to a DbtRemoteHook. An optional connection id is
supported for remotes that require it.
"""
scheme = urlparse(str(profiles_dir)).scheme
        remote = self.get_remote(scheme, self.profiles_conn_id)
return remote.download_dbt_profiles(profiles_dir, destination)
def download_dbt_project(
self,
project_dir: URLLike,
destination: URLLike,
) -> Path:
"""Pull a dbt project from a given project_dir.

        This operation is delegated to a DbtRemoteHook. An optional connection id is
supported for remotes that require it.
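
        For example (illustrative URLs):

        >>> hook.download_dbt_project("s3://my-bucket/project/", "/tmp/dbt")  # doctest: +SKIP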
"""
scheme = urlparse(str(project_dir)).scheme
remote = self.get_remote(scheme, self.project_conn_id)
return remote.download_dbt_project(project_dir, destination)
def upload_dbt_project(
self,
project_dir: URLLike,
destination: URLLike,
replace: bool = False,
delete_before: bool = False,
) -> None:
"""Push a dbt project from a given project_dir.

        This operation is delegated to a DbtRemoteHook. An optional connection id is
supported for remotes that require it.
"""
scheme = urlparse(str(destination)).scheme
remote = self.get_remote(scheme, self.project_conn_id)
return remote.upload_dbt_project(
project_dir, destination, replace=replace, delete_before=delete_before
)
def run_dbt_task(
self,
command: str,
upload_dbt_project: bool = False,
delete_before_upload: bool = False,
replace_on_upload: bool = False,
artifacts: Optional[Iterable[str]] = None,
env_vars: Optional[Dict[str, Any]] = None,
write_perf_info: bool = False,
**kwargs,
) -> DbtTaskResult:
"""Run a dbt task with a given configuration and return the results.

        The configuration used determines the task that will be run.

        Returns:
            A tuple containing a boolean indicating success and optionally the results
            of running the dbt command.
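
        Example:
            Running dbt test and saving the run_results.json artifact (paths are
            illustrative):

            >>> result = hook.run_dbt_task(  # doctest: +SKIP
            ...     "test",
            ...     project_dir="/path/to/project",
            ...     artifacts=["run_results.json"],
            ... )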
"""
from dbt.adapters.factory import adapter_management
from dbt.task.base import get_nearest_project_dir
from dbt.task.clean import CleanTask
from dbt.task.deps import DepsTask
from dbt.tracking import track_run
config = self.get_dbt_task_config(command, **kwargs)
extra_target = self.get_dbt_target_from_connection(config.target)
with self.dbt_directory(
config,
upload_dbt_project=upload_dbt_project,
delete_before_upload=delete_before_upload,
replace_on_upload=replace_on_upload,
env_vars=env_vars,
) as dbt_dir:
# When creating tasks via from_args, dbt switches to the project directory.
# We have to do that here as we are not using from_args.
nearest_project_dir = get_nearest_project_dir(config.project_dir)
with chdir_ctx(nearest_project_dir):
self.ensure_profiles(config)
with adapter_management():
task, runtime_config = config.create_dbt_task(
extra_target, write_perf_info
)
requires_profile = isinstance(task, (CleanTask, DepsTask))
self.setup_dbt_logging(task, config.debug)
if runtime_config is not None and not requires_profile:
# The deps command installs the dependencies, which means they
# may not exist before deps runs and the following would raise a
# CompilationError.
runtime_config.load_dependencies()
with track_run(task):
results = task.run()
success = task.interpret_results(results)
if artifacts is None:
return DbtTaskResult(success, results, {})
saved_artifacts = {}
for artifact in artifacts:
artifact_path = Path(dbt_dir) / "target" / artifact
if not artifact_path.exists():
self.log.warning(
"Required dbt artifact %s was not found. "
"Perhaps dbt failed and couldn't generate it.",
artifact,
)
continue
with open(artifact_path) as artifact_file:
json_artifact = json.load(artifact_file)
saved_artifacts[artifact] = json_artifact
return DbtTaskResult(success, results, saved_artifacts)
def get_dbt_task_config(self, command: str, **config_kwargs) -> BaseConfig:
"""Initialize a configuration for given dbt command with given kwargs."""
from airflow_dbt_python.utils.configs import ConfigFactory
return ConfigFactory.from_str(command).create_config(**config_kwargs)
@contextmanager
def dbt_directory(
self,
config,
upload_dbt_project: bool = False,
delete_before_upload: bool = False,
replace_on_upload: bool = False,
env_vars: Optional[Dict[str, Any]] = None,
) -> Iterator[str]:
"""Provides a temporary directory to execute dbt.

        Creates a temporary directory for dbt to run in and prepares the dbt files
        if they need to be pulled from a remote, like S3. If a remote is being used
        and upload_dbt_project is True, we push the project back to the remote before
        leaving the temporary directory. Pushing back a project persists the output
        of commands like deps or docs generate.

        Yields:
            The temporary directory's name.
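
        Example:
            Simplified internal usage, assuming config is a dbt task configuration:

            >>> with hook.dbt_directory(config) as tmp_dir:  # doctest: +SKIP
            ...     pass  # config.project_dir now points inside tmp_dir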
"""
from airflow_dbt_python.utils.env import update_environment
store_profiles_dir = config.profiles_dir
store_project_dir = config.project_dir
with update_environment(env_vars):
with DbtTemporaryDirectory(prefix="airflow_tmp") as tmp_dir:
self.log.info("Initializing temporary directory: %s", tmp_dir)
try:
project_dir, profiles_dir = self.prepare_directory(
tmp_dir,
store_project_dir,
store_profiles_dir,
)
except Exception as e:
raise AirflowException(
"Failed to prepare temporary directory for dbt execution"
) from e
config.project_dir = project_dir
config.profiles_dir = profiles_dir
if getattr(config, "state", None) is not None:
state = Path(getattr(config, "state", ""))
# Since we are running in a temporary directory, we need to make
# state paths relative to this temporary directory.
if not state.is_absolute():
setattr(config, "state", str(Path(tmp_dir) / state))
yield tmp_dir
if upload_dbt_project is True:
self.log.info("Uploading dbt project to: %s", store_project_dir)
self.upload_dbt_project(
tmp_dir,
store_project_dir,
replace=replace_on_upload,
delete_before=delete_before_upload,
)
config.profiles_dir = store_profiles_dir
config.project_dir = store_project_dir
def prepare_directory(
self,
tmp_dir: str,
project_dir: URLLike,
profiles_dir: Optional[URLLike] = None,
) -> tuple[str, Optional[str]]:
"""Prepares a dbt directory for execution of a dbt task.

        Preparation involves downloading the required dbt project files and
profiles.yml.
"""
project_dir_path = self.download_dbt_project(
project_dir,
tmp_dir,
)
new_project_dir = str(project_dir_path) + "/"
if (project_dir_path / "profiles.yml").exists():
# We may have downloaded the profiles.yml file together
# with the project.
return new_project_dir, new_project_dir
if profiles_dir is not None:
profiles_file_path = self.download_dbt_profiles(
profiles_dir,
tmp_dir,
)
new_profiles_dir = str(profiles_file_path.parent) + "/"
else:
new_profiles_dir = None
return new_project_dir, new_profiles_dir
def setup_dbt_logging(self, task: BaseTask, debug: Optional[bool]):
"""Setup dbt logging.
Starting with dbt v1, dbt initializes two loggers: default_file and
default_stdout. As these are initialized by the CLI app, we need to
initialize them here.
"""
from dbt.events.logging import setup_event_logger
from dbt.flags import get_flags
flags = get_flags()
setup_event_logger(flags)
configured_file = logging.getLogger("configured_file")
file_log = logging.getLogger("file_log")
stdout_log = logging.getLogger("stdout_log")
stdout_log.propagate = True
if not debug:
# We have to do this after setting logs up as dbt hasn't
# configured the loggers before the call to setup_event_logger.
# In the future, handlers may also be cleared or setup to use Airflow's.
file_log.setLevel("INFO")
file_log.propagate = False
configured_file.setLevel("INFO")
configured_file.propagate = False
def ensure_profiles(self, config: BaseConfig):
"""Ensure a profiles file exists."""
if config.profiles_dir is not None:
# We expect one to exist given that we have passed a profiles_dir.
return
profiles_path = Path.home() / ".dbt/profiles.yml"
config.profiles_dir = str(profiles_path.parent)
if not profiles_path.exists():
profiles_path.parent.mkdir(exist_ok=True)
with profiles_path.open("w", encoding="utf-8") as f:
f.write("flags:\n send_anonymous_usage_stats: false\n")
def get_dbt_target_from_connection(
self, target: Optional[str]
) -> Optional[dict[str, Any]]:
"""Return a dictionary of connection details to use as a dbt target.

        The connection details are fetched from an Airflow connection identified by
        target or self.dbt_conn_id.

        Args:
            target: The target name to use as an Airflow connection ID. If omitted,
                we will use self.dbt_conn_id.

        Returns:
            A dictionary with a configuration for a dbt target, or None if no
            matching Airflow connection is found for the given dbt target.
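
        Example:
            Illustrative lookup by Airflow connection id, returning a mapping of
            target name to connection details:

            >>> hook.get_dbt_target_from_connection("my_postgres")  # doctest: +SKIP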
"""
conn_id = target or self.dbt_conn_id
if conn_id is None:
return None
try:
conn = self.get_connection(conn_id)
except AirflowException:
self.log.debug(
"No Airflow connection matching dbt target %s was found.", target
)
return None
details = self.get_dbt_details_from_connection(conn)
return {conn_id: details}
def get_dbt_details_from_connection(self, conn: Connection) -> dict[str, Any]:
"""Extract dbt connection details from Airflow Connection.

        dbt connection details may be present as Airflow Connection attributes or in
        the Connection's extras. This class' conn_params and conn_extra_params are
        used to fetch the required parameters from the Connection's attributes and
        extras, respectively. If conn_extra_params is empty, all extras are merged
        into the details.

        Subclasses may override these class attributes to narrow down the connection
        details for a specific dbt target (like Postgres or Redshift).

        Args:
            conn: The Airflow Connection to extract dbt connection details from.

        Returns:
            A dictionary of dbt connection details.
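
        Example:
            With a Postgres-style connection (illustrative values), attributes are
            mapped to dbt profile keys via conn_params:

            >>> conn = Connection(  # doctest: +SKIP
            ...     conn_id="my_db",
            ...     conn_type="postgres",
            ...     host="localhost",
            ...     login="user",
            ...     port=5432,
            ... )
            >>> hook.get_dbt_details_from_connection(conn)  # doctest: +SKIP
            {'type': 'postgres', 'host': 'localhost', 'dbname': 'my_db', 'user': 'user', 'port': 5432}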
"""
dbt_details = {}
for param in self.conn_params:
if isinstance(param, DbtConnectionParam):
key = param.override_name
value = getattr(conn, param.name, param.default)
else:
key = param
value = getattr(conn, key, None)
if value is None:
continue
dbt_details[key] = value
extra = conn.extra_dejson
if not self.conn_extra_params:
return {**dbt_details, **extra}
for param in self.conn_extra_params:
if isinstance(param, DbtConnectionParam):
key = param.override_name
value = extra.get(param.name, param.default)
else:
key = param
value = extra.get(key, None)
if value is None:
continue
dbt_details[key] = value
return dbt_details
class DbtPostgresHook(DbtHook):
"""A hook to interact with dbt using a Postgres connection."""
conn_type = "postgres"
hook_name = "dbt Postgres Hook"
conn_params = [
DbtConnectionParam("conn_type", "type", "postgres"),
"host",
"schema",
DbtConnectionParam("login", "user"),
"password",
"port",
]
conn_extra_params = [
"dbname",
"threads",
"keepalives_idle",
"connect_timeout",
"retries",
"search_path",
"role",
"sslmode",
]
class DbtRedshiftHook(DbtPostgresHook):
"""A hook to interact with dbt using a Redshift connection."""
conn_type = "redshift"
hook_name = "dbt Redshift Hook"
conn_extra_params = DbtPostgresHook.conn_extra_params + [
"ra3_node",
"iam_profile",
"iam_duration_secons",
"autocreate",
"db_groups",
]
class DbtSnowflakeHook(DbtHook):
"""A hook to interact with dbt using a Snowflake connection."""
conn_type = "snowflake"
hook_name = "dbt Snowflake Hook"
conn_params = [
DbtConnectionParam("conn_type", "type", "postgres"),
"host",
"schema",
DbtConnectionParam("login", "user"),
"password",
]
conn_extra_params = [
"account",
"role",
"database",
"warehouse",
"threads",
"client_session_keep_alive",
"query_tag",
"connect_retries",
"connect_timeout",
"retry_on_database_errors",
"retry_all",
]