# Copyright (c) 2012, 2018, Oracle and/or its affiliates. All rights reserved. # # This program is free software; you can redistribute it and/or modify # it under the terms of the GNU General Public License, version 2.0, as # published by the Free Software Foundation. # # This program is also distributed with certain software (including # but not limited to OpenSSL) that is licensed under separate terms, # as designated in a particular file or component or in included license # documentation. The authors of MySQL hereby grant you an # additional permission to link the program and your derivative works # with the separately licensed software that they have included with # MySQL. # # Without limiting anything contained in the foregoing, this file, # which is part of MySQL Connector/Python, is also subject to the # Universal FOSS Exception, version 1.0, a copy of which can be found at # http://oss.oracle.com/licenses/universal-foss-exception. # # This program is distributed in the hope that it will be useful, but # WITHOUT ANY WARRANTY; without even the implied warranty of # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. # See the GNU General Public License, version 2.0, for more details. # # You should have received a copy of the GNU General Public License # along with this program; if not, write to the Free Software Foundation, Inc., # 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA """Module implementing low-level socket communication with MySQL servers. """ from collections import deque import os import socket import struct import sys import zlib try: import ssl TLS_VERSIONS = { "TLSv1": ssl.PROTOCOL_TLSv1, "TLSv1.1": ssl.PROTOCOL_TLSv1_1, "TLSv1.2": ssl.PROTOCOL_TLSv1_2} # TLSv1.3 included in PROTOCOL_TLS, but PROTOCOL_TLS is not included on 3.4 if hasattr(ssl, "PROTOCOL_TLS"): TLS_VERSIONS["TLSv1.3"] = ssl.PROTOCOL_TLS # pylint: disable=E1101 else: TLS_VERSIONS["TLSv1.3"] = ssl.PROTOCOL_SSLv23 # Alias of PROTOCOL_TLS if hasattr(ssl, "HAS_TLSv1_3") and ssl.HAS_TLSv1_3: TLS_V1_3_SUPPORTED = True else: TLS_V1_3_SUPPORTED = False except: # If import fails, we don't have SSL support. TLS_V1_3_SUPPORTED = False pass from . import constants, errors from .errors import InterfaceError from .catch23 import PY2, init_bytearray, struct_unpack def _strioerror(err): """Reformat the IOError error message This function reformats the IOError error message. """ if not err.errno: return str(err) return '{errno} {strerr}'.format(errno=err.errno, strerr=err.strerror) def _prepare_packets(buf, pktnr): """Prepare a packet for sending to the MySQL server""" pkts = [] pllen = len(buf) maxpktlen = constants.MAX_PACKET_LENGTH while pllen > maxpktlen: pkts.append(b'\xff\xff\xff' + struct.pack(' 255: self._packet_number = 0 return self._packet_number @property def next_compressed_packet_number(self): """Increments the compressed packet number""" self._compressed_packet_number = self._compressed_packet_number + 1 if self._compressed_packet_number > 255: self._compressed_packet_number = 0 return self._compressed_packet_number def open_connection(self): """Open the socket""" raise NotImplementedError def get_address(self): """Get the location of the socket""" raise NotImplementedError def shutdown(self): """Shut down the socket before closing it""" try: self.sock.shutdown(socket.SHUT_RDWR) self.sock.close() del self._packet_queue except (socket.error, AttributeError): pass def close_connection(self): """Close the socket""" try: self.sock.close() del self._packet_queue except (socket.error, AttributeError): pass def __del__(self): self.shutdown() def send_plain(self, buf, packet_number=None, compressed_packet_number=None): """Send packets to the MySQL server""" if packet_number is None: self.next_packet_number # pylint: disable=W0104 else: self._packet_number = packet_number packets = _prepare_packets(buf, self._packet_number) for packet in packets: try: if PY2: self.sock.sendall(buffer(packet)) # pylint: disable=E0602 else: self.sock.sendall(packet) except IOError as err: raise errors.OperationalError( errno=2055, values=(self.get_address(), _strioerror(err))) except AttributeError: raise errors.OperationalError(errno=2006) send = send_plain def send_compressed(self, buf, packet_number=None, compressed_packet_number=None): """Send compressed packets to the MySQL server""" if packet_number is None: self.next_packet_number # pylint: disable=W0104 else: self._packet_number = packet_number if compressed_packet_number is None: self.next_compressed_packet_number # pylint: disable=W0104 else: self._compressed_packet_number = compressed_packet_number pktnr = self._packet_number pllen = len(buf) zpkts = [] maxpktlen = constants.MAX_PACKET_LENGTH if pllen > maxpktlen: pkts = _prepare_packets(buf, pktnr) if PY2: tmpbuf = bytearray() for pkt in pkts: tmpbuf += pkt tmpbuf = buffer(tmpbuf) # pylint: disable=E0602 else: tmpbuf = b''.join(pkts) del pkts zbuf = zlib.compress(tmpbuf[:16384]) header = (struct.pack(' maxpktlen: zbuf = zlib.compress(tmpbuf[:maxpktlen]) header = (struct.pack(' 50: zbuf = zlib.compress(pkt) zpkts.append(struct.pack(' 0: raise errors.InterfaceError(errno=2013) packet_view = packet_view[read:] rest -= read return packet except IOError as err: raise errors.OperationalError( errno=2055, values=(self.get_address(), _strioerror(err))) def recv_py26_plain(self): """Receive packets from the MySQL server""" try: # Read the header of the MySQL packet, 4 bytes header = bytearray(b'') header_len = 0 while header_len < 4: chunk = self.sock.recv(4 - header_len) if not chunk: raise errors.InterfaceError(errno=2013) header += chunk header_len = len(header) # Save the packet number and payload length self._packet_number = header[3] payload_len = struct_unpack(" 0: chunk = self.sock.recv(rest) if not chunk: raise errors.InterfaceError(errno=2013) payload += chunk rest = payload_len - len(payload) return header + payload except IOError as err: raise errors.OperationalError( errno=2055, values=(self.get_address(), _strioerror(err))) if sys.version_info[0:2] == (2, 6): recv = recv_py26_plain recv_plain = recv_py26_plain else: recv = recv_plain def _split_zipped_payload(self, packet_bunch): """Split compressed payload""" while packet_bunch: if PY2: payload_length = struct.unpack_from( " 1: tls_version = tls_versions[1] ssl_protocol = TLS_VERSIONS[tls_version] context = ssl.SSLContext(ssl_protocol) if tls_version == "TLSv1.3": if "TLSv1.2" not in tls_versions: context.options |= ssl.OP_NO_TLSv1_2 if "TLSv1.1" not in tls_versions: context.options |= ssl.OP_NO_TLSv1_1 if "TLSv1" not in tls_versions: context.options |= ssl.OP_NO_TLSv1 context.check_hostname = False context.verify_mode = cert_reqs context.load_default_certs() if ca: try: context.load_verify_locations(ca) except (IOError, ssl.SSLError) as err: self.sock.close() raise InterfaceError( "Invalid CA Certificate: {}".format(err)) if cert: try: context.load_cert_chain(cert, key) except (IOError, ssl.SSLError) as err: self.sock.close() raise InterfaceError( "Invalid Certificate/Key: {}".format(err)) if cipher_suites: context.set_ciphers(cipher_suites) if hasattr(self, "server_host"): self.sock = context.wrap_socket( self.sock, server_hostname=self.server_host) else: self.sock = context.wrap_socket(self.sock) if verify_identity: context.check_hostname = True hostnames = [self.server_host] if os.name == 'nt' and self.server_host == 'localhost': hostnames = ['localhost', '127.0.0.1'] aliases = socket.gethostbyaddr(self.server_host) hostnames.extend([aliases[0]] + aliases[1]) match_found = False errs = [] for hostname in hostnames: try: ssl.match_hostname(self.sock.getpeercert(), hostname) except ssl.CertificateError as err: errs.append(str(err)) else: match_found = True break if not match_found: self.sock.close() raise InterfaceError("Unable to verify server identity: {}" "".format(", ".join(errs))) except NameError: raise errors.NotSupportedError( "Python installation has no SSL support") except (ssl.SSLError, IOError) as err: raise errors.InterfaceError( errno=2055, values=(self.get_address(), _strioerror(err))) except ssl.CertificateError as err: raise errors.InterfaceError(str(err)) except NotImplementedError as err: raise errors.InterfaceError(str(err)) # pylint: enable=C0103,E1101 class MySQLUnixSocket(BaseMySQLSocket): """MySQL socket class using UNIX sockets Opens a connection through the UNIX socket of the MySQL Server. """ def __init__(self, unix_socket='/tmp/mysql.sock'): super(MySQLUnixSocket, self).__init__() self.unix_socket = unix_socket def get_address(self): return self.unix_socket def open_connection(self): try: self.sock = socket.socket(socket.AF_UNIX, # pylint: disable=E1101 socket.SOCK_STREAM) self.sock.settimeout(self._connection_timeout) self.sock.connect(self.unix_socket) except IOError as err: raise errors.InterfaceError( errno=2002, values=(self.get_address(), _strioerror(err))) except Exception as err: raise errors.InterfaceError(str(err)) class MySQLTCPSocket(BaseMySQLSocket): """MySQL socket class using TCP/IP Opens a TCP/IP connection to the MySQL Server. """ def __init__(self, host='127.0.0.1', port=3306, force_ipv6=False): super(MySQLTCPSocket, self).__init__() self.server_host = host self.server_port = port self.force_ipv6 = force_ipv6 self._family = 0 def get_address(self): return "{0}:{1}".format(self.server_host, self.server_port) def open_connection(self): """Open the TCP/IP connection to the MySQL server """ # Get address information addrinfo = [None] * 5 try: addrinfos = socket.getaddrinfo(self.server_host, self.server_port, 0, socket.SOCK_STREAM, socket.SOL_TCP) # If multiple results we favor IPv4, unless IPv6 was forced. for info in addrinfos: if self.force_ipv6 and info[0] == socket.AF_INET6: addrinfo = info break elif info[0] == socket.AF_INET: addrinfo = info break if self.force_ipv6 and addrinfo[0] is None: raise errors.InterfaceError( "No IPv6 address found for {0}".format(self.server_host)) if addrinfo[0] is None: addrinfo = addrinfos[0] except IOError as err: raise errors.InterfaceError( errno=2003, values=(self.get_address(), _strioerror(err))) else: (self._family, socktype, proto, _, sockaddr) = addrinfo # Instanciate the socket and connect try: self.sock = socket.socket(self._family, socktype, proto) self.sock.settimeout(self._connection_timeout) self.sock.connect(sockaddr) except IOError as err: raise errors.InterfaceError( errno=2003, values=(self.get_address(), _strioerror(err))) except Exception as err: raise errors.OperationalError(str(err))