# Copyright (C) 2016-present the asyncpg authors and contributors
# <see AUTHORS file>
#
# This module is part of asyncpg and is released under
# the Apache 2.0 License: http://www.apache.org/licenses/LICENSE-2.0
import asyncio
import collections
import functools
import getpass
import os
import pathlib
import platform
import re
import socket
import ssl as ssl_module
import stat
import struct
import time
import typing
import urllib.parse
import warnings
from . import compat
from . import exceptions
from . import protocol
# Immutable record grouping the per-connection parameters: credentials,
# database name, SSL settings, connect timeout and server settings.
_ConnectionParameters = collections.namedtuple(
    typename='ConnectionParameters',
    field_names=(
        'user',
        'password',
        'database',
        'ssl',
        'ssl_is_advisory',
        'connect_timeout',
        'server_settings',
    ),
)
# Immutable record grouping the client-side tunables: command timeout
# and prepared-statement cache limits.
# NOTE: the generated type is named 'ConnectionConfiguration', which
# differs from the module-level variable name.
_ClientConfiguration = collections.namedtuple(
    typename='ConnectionConfiguration',
    field_names=(
        'command_timeout',
        'statement_cache_size',
        'max_cached_statement_lifetime',
        'max_cacheable_statement_size',
    ),
)
# Name of the operating system as reported by platform.uname().
_system = platform.uname().system

# Password file name: 'pgpass.conf' on Windows, '.pgpass' elsewhere.
PGPASSFILE = 'pgpass.conf' if _system == 'Windows' else '.pgpass'
def _read_password_file(passfile: pathlib.Path) \
        -> typing.List[typing.Tuple[str, ...]]:
    """Load the records stored in a password file.

    Every line that is neither blank nor a ``#`` comment is split on
    unescaped colons (at most four splits, so at most five fields) and
    returned as a tuple of strings.  Fields are not unescaped.

    :return:
        A list of field tuples.  An empty list is returned when the
        file does not exist, is not a plain file, or -- on non-Windows
        systems -- grants any access to group or others; the last two
        cases also emit a warning.  An ``IOError`` is swallowed and
        whatever was parsed before it is returned.
    """
    field_separator = re.compile(r'(?<!\\):')
    records = []

    try:
        if not passfile.exists():
            return []

        if not passfile.is_file():
            warnings.warn(
                'password file {!r} is not a plain file'.format(passfile))
            return []

        # Short-circuits on Windows, so stat() is only called elsewhere.
        too_permissive = (
            _system != 'Windows'
            and passfile.stat().st_mode & (stat.S_IRWXG | stat.S_IRWXO)
        )
        if too_permissive:
            warnings.warn(
                'password file {!r} has group or world access; '
                'permissions should be u=rw (0600) or less'.format(
                    passfile))
            return []

        with passfile.open('rt') as stream:
            for raw_line in stream:
                entry = raw_line.strip()
                if not entry:
                    # Blank line.
                    continue
                if entry.startswith('#'):
                    # Comment line.
                    continue

                # A backslash escapes both itself and the colon field
                # separator.  Escaped backslashes are hidden behind a
                # newline placeholder (which cannot occur in a stripped
                # line) so that they do not shield a following colon,
                # and are restored once the line has been split.
                masked = entry.replace(R'\\', '\n')
                fields = field_separator.split(masked, maxsplit=4)
                records.append(
                    tuple(field.replace('\n', R'\\') for field in fields))
    except IOError:
        pass

    return records
def _read_password_from_pgpass(
        *, passfile: typing.Optional[pathlib.Path],
        hosts: typing.List[str],
        ports: typing.List[int],
        database: str,
        user: str):
    """Parse the pgpass file and return the matching password.

    Hosts and ports are examined pairwise, in order.  For each pair the
    password file records are scanned top to bottom, and the password
    of the first record whose host, port, database and user fields all
    match (a ``*`` field matches anything) is returned.

    :param passfile:
        Path of the password file; ``None`` means there is no file
        to consult.
    :param hosts:
        Candidate host names; a host starting with ``/`` (a Unix
        socket path) is matched as ``localhost``.
    :param ports:
        Ports paired positionally with *hosts*.
    :param database:
        Database name to match.
    :param user:
        User name to match.

    :return:
        Password string, if found, ``None`` otherwise.
    """
    if passfile is None:
        # Nothing to read; avoid calling path methods on None.
        return None

    # Keep only well-formed records.  Lines with fewer than the five
    # expected fields are ignored instead of breaking the unpacking
    # below with a ValueError.
    passtab = [
        entry for entry in _read_password_file(passfile)
        if len(entry) == 5
    ]
    if not passtab:
        return None

    for host, port in zip(hosts, ports):
        if host.startswith('/'):
            # Unix sockets get normalized into 'localhost'
            host = 'localhost'

        # Password file fields are strings; compare the port as text.
        port_text = str(port)

        for phost, pport, pdatabase, puser, ppassword in passtab:
            if phost != '*' and phost != host:
                continue

            if pport != '*' and pport != port_text:
                continue

            if pdatabase != '*' and pdatabase != database:
                continue

            if puser != '*' and puser != user:
                continue

            # Found a match.
            return ppassword

    return None
def _validate_port_spec(hosts, port):
    """Return a list of ports holding one entry per host.

    :param hosts:
        Sequence of hosts; only its length is used.
    :param port:
        Either a single port value, which is repeated for every host,
        or a list or tuple of ports, whose length must equal the
        number of hosts.

    :return:
        A list with ``len(hosts)`` ports.  A list argument is returned
        as is; a tuple argument is converted to a list.

    :raises exceptions.InterfaceError:
        If a list or tuple of ports does not have the same length as
        *hosts*.
    """
    if isinstance(port, (list, tuple)):
        # If there is a sequence of ports, its length must
        # match that of the host list.
        if len(port) != len(hosts):
            raise exceptions.InterfaceError(
                'could not match {} port numbers to {} hosts'.format(
                    len(port), len(hosts)))
        if isinstance(port, tuple):
            # Always hand back a list, whatever sequence was given.
            port = list(port)
    else:
        port = [port] * len(hosts)

    return port
def _parse_hostlist(hostlist, port):
    """Split a comma-separated host list into hosts and ports.

    Each element of *hostlist* is either ``host``, ``host:port``,
    ``[ipv6-address]`` or ``[ipv6-address]:port``.  Square brackets
    around an IPv6 address are removed from the returned host.

    :param hostlist:
        String of one or more comma-separated host specifications.
    :param port:
        Explicit port, or list of ports.  When truthy it takes
        precedence over any port embedded in *hostlist* and is
        validated against the number of hosts.  Otherwise the port of
        each host comes from its host specification, falling back to
        the ``PGPORT`` environment variable (a single port or a
        comma-separated list) and finally to ``5432``.

    :return:
        A ``(hosts, ports)`` tuple of two lists of equal length.

    :raises ValueError:
        If a bracketed IPv6 host specification is malformed, or a port
        taken from *hostlist* or ``PGPORT`` is not an integer.
    """
    # str.split() returns a one-element list when there is no comma.
    hostspecs = hostlist.split(',')

    # Ports embedded in the host list are honoured only when no
    # explicit port was passed.
    use_hostspec_ports = not port

    if use_hostspec_ports:
        portspec = os.environ.get('PGPORT')
        if portspec:
            if ',' in portspec:
                default_port = [int(p) for p in portspec.split(',')]
            else:
                default_port = int(portspec)
        else:
            default_port = 5432

        default_port = _validate_port_spec(hostspecs, default_port)
    else:
        port = _validate_port_spec(hostspecs, port)

    hosts = []
    hostlist_ports = []

    for i, hostspec in enumerate(hostspecs):
        if hostspec.startswith('['):
            # Bracketed IPv6 address: the colons inside the brackets
            # belong to the address and must not be mistaken for the
            # host/port separator.
            m = re.fullmatch(r'\[([^\]]+)\](?::([0-9]+))?', hostspec)
            if m is None:
                raise ValueError(
                    'invalid IPv6 address in the host list: {!r}'.format(
                        hostspec))
            addr = m.group(1)
            hostspec_port = m.group(2)
        else:
            addr, _, hostspec_port = hostspec.partition(':')

        hosts.append(addr)

        if use_hostspec_ports:
            if hostspec_port:
                hostlist_ports.append(int(hostspec_port))
            else:
                hostlist_ports.append(default_port[i])

    if use_hostspec_ports:
        port = hostlist_ports

    return hosts, port
def _parse_connect_dsn_and_args(*, dsn, host, port, user,
password, passfile, database, ssl,
connect_timeout, server_settings):
# `auth_hosts` is the version of host information for the purposes
# of reading the pgpass file.
auth_hosts = None
if dsn:
parsed = urllib.parse.urlparse(dsn)
if parsed.scheme not in {'postgresql', 'postgres'}:
raise ValueError(
'invalid DSN: scheme is expected to be either '
'"postgresql" or "postgres", got {!r}'.format(parsed.scheme))
if not host and parsed.netloc:
if '@' in parsed.netloc:
auth, _, hostspec = parsed.netloc.partition('@')
else:
hostspec = parsed.netloc
if hostspec:
host, port = _parse_hostlist(hostspec, port)
if parsed.path and database is None:
database = parsed.path
if database.startswith('/'):
database = database[1:]
if parsed.username and user is None:
user = parsed.username
if parsed.password and password is None:
password = parsed.password
if parsed.query:
query = urllib.parse.parse_qs(parsed.query, strict_parsing=True)
for key, val in query.items():
if isinstance(val, list):
query[key] = val[-1]
if 'port' in query:
val = query.pop('port')
if not port and val:
port = [int(p) for p in val.split(',')]
if 'host' in query:
val = query.pop('host')
if not host and val:
host, port = _parse_hostlist(val, port)
if 'dbname' in query:
val = query.pop('dbname')
if database is None:
database = val
if 'database' in query:
val = query.pop('database')
if database is None:
database = val
if 'user' in query:
val = query.pop('user')
if user is None:
user = val
if 'password' in query:
val = query.pop('password')
if password is None:
password = val
if 'passfile' in query:
val = query.pop('passfile')
if passfile is None:
passfile = val
if 'sslmode' in query:
val = query.pop('sslmode')
if ssl is None:
ssl = val
if query:
if server_settings is None:
server_settings = query
else:
server_settings = {**query, **server_settings}
if not host:
hostspec = os.environ.get('PGHOST')
if hostspec:
host, port = _parse_hostlist(hostspec, port)
if not host:
auth_hosts = ['localhost']
if _system == 'Windows':
host = ['localhost']
else:
host = ['/run/postgresql', '/var/run/postgresql',
'/tmp', '/private/tmp', 'localhost']
if not isinstance(host, list):
host = [host]
if auth_hosts is None:
auth_hosts = host
if not port:
portspec = os.environ.get('PGPORT')
if portspec:
if ',' in portspec:
port = [int(p) for p in portspec.split(',')]
else:
port = int(portspec)
else:
port = 5432
elif isinstance(port, (list, tuple)):
port = [int(p) for p in port]
else:
port = int(port)
port = _validate_port_spec(host, port)
if user is None:
user = os.getenv('PGUSER')
if not user:
user = getpass.getuser()
if password is None:
password = os.getenv('PGPASSWORD')
if database is None:
database = os.getenv('PGDATABASE')
if database is None:
database = user
if user is None:
raise exceptions.InterfaceError(
'could not determine user name to connect with')
if database is None:
raise exceptions.InterfaceError(
'could not determine database name to connect to')
if password is None:
if passfile is None:
Loading ...