mirror of https://github.com/sgoudham/Enso-Bot.git
You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
763 lines
29 KiB
Python
763 lines
29 KiB
Python
5 years ago
|
# Copyright (c) 2009, 2019, Oracle and/or its affiliates. All rights reserved.
|
||
|
#
|
||
|
# This program is free software; you can redistribute it and/or modify
|
||
|
# it under the terms of the GNU General Public License, version 2.0, as
|
||
|
# published by the Free Software Foundation.
|
||
|
#
|
||
|
# This program is also distributed with certain software (including
|
||
|
# but not limited to OpenSSL) that is licensed under separate terms,
|
||
|
# as designated in a particular file or component or in included license
|
||
|
# documentation. The authors of MySQL hereby grant you an
|
||
|
# additional permission to link the program and your derivative works
|
||
|
# with the separately licensed software that they have included with
|
||
|
# MySQL.
|
||
|
#
|
||
|
# Without limiting anything contained in the foregoing, this file,
|
||
|
# which is part of MySQL Connector/Python, is also subject to the
|
||
|
# Universal FOSS Exception, version 1.0, a copy of which can be found at
|
||
|
# http://oss.oracle.com/licenses/universal-foss-exception.
|
||
|
#
|
||
|
# This program is distributed in the hope that it will be useful, but
|
||
|
# WITHOUT ANY WARRANTY; without even the implied warranty of
|
||
|
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
|
||
|
# See the GNU General Public License, version 2.0, for more details.
|
||
|
#
|
||
|
# You should have received a copy of the GNU General Public License
|
||
|
# along with this program; if not, write to the Free Software Foundation, Inc.,
|
||
|
# 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA
|
||
|
|
||
|
"""Implements the MySQL Client/Server protocol
|
||
|
"""
|
||
|
|
||
|
import struct
|
||
|
import datetime
|
||
|
from decimal import Decimal
|
||
|
|
||
|
from .constants import (
|
||
|
FieldFlag, ServerCmd, FieldType, ClientFlag)
|
||
|
from . import errors, utils
|
||
|
from .authentication import get_auth_plugin
|
||
|
from .catch23 import PY2, struct_unpack
|
||
|
from .errors import DatabaseError, get_exception
|
||
|
|
||
|
# Protocol version this client implements. parse_handshake() compares it
# with the version byte of the server's handshake packet and raises
# DatabaseError when the two differ.
PROTOCOL_VERSION = 10
|
||
|
|
||
|
|
||
|
class MySQLProtocol(object):
    """Implements the MySQL client/server protocol.

    Creates and parses MySQL packets. Packets handed to the parse_*
    methods generally include the 4-byte packet header, so their
    payload starts at offset 4.
    """
|
||
|
|
||
|
def _connect_with_db(self, client_flags, database):
    """Return the NUL-terminated database name for a handshake packet.

    The name is only included when ``client_flags`` has
    ClientFlag.CONNECT_WITH_DB set and ``database`` is non-empty;
    otherwise a single NUL byte is returned.
    """
    wants_db = client_flags & ClientFlag.CONNECT_WITH_DB and database
    if not wants_db:
        return b'\x00'
    return database.encode('utf8') + b'\x00'
|
||
|
|
||
|
def _auth_response(self, client_flags, username, password, database,
                   auth_plugin, auth_data, ssl_enabled):
    """Build the authentication-response part of a handshake packet.

    An empty (false) ``password`` yields a single NUL byte. Otherwise the
    plugin class returned by get_auth_plugin(``auth_plugin``) computes
    the response from ``auth_data``. With ClientFlag.SECURE_CONNECTION
    the response is prefixed with its one-byte length; without it, the
    response is NUL-terminated.

    Raises:
        errors.InterfaceError: if the plugin raises TypeError or
            InterfaceError while being created or computing the response.
    """
    if not password:
        return b'\x00'

    try:
        plugin_class = get_auth_plugin(auth_plugin)
        plugin = plugin_class(auth_data, username=username,
                              password=password, database=database,
                              ssl_enabled=ssl_enabled)
        response = plugin.auth_response()
    except (TypeError, errors.InterfaceError) as exc:
        raise errors.InterfaceError(
            "Failed authentication: {0}".format(str(exc)))

    if not client_flags & ClientFlag.SECURE_CONNECTION:
        return response + b'\x00'
    return struct.pack('<B', len(response)) + response
|
||
|
|
||
|
def make_auth(self, handshake, username=None, password=None, database=None,
              charset=45, client_flags=0,
              max_allowed_packet=1073741824, ssl_enabled=False,
              auth_plugin=None, conn_attrs=None):
    """Build the client authentication (handshake response) packet.

    Layout: client flags (4 bytes), max packet size (4), charset (2),
    22 filler bytes, NUL-terminated user name, authentication response
    (see _auth_response), database part (see _connect_with_db), then the
    NUL-terminated auth plugin name when ClientFlag.PLUGIN_AUTH is set,
    and the encoded connection attributes when ClientFlag.CONNECT_ARGS
    is set and ``conn_attrs`` is not None.

    ``auth_plugin`` overrides the plugin named in ``handshake``.

    Raises:
        errors.ProgrammingError: if ``handshake`` lacks 'auth_data' or
            'auth_plugin', or cannot be indexed.
    """
    try:
        auth_data = handshake['auth_data']
        auth_plugin = auth_plugin or handshake['auth_plugin']
    except (TypeError, KeyError) as exc:
        raise errors.ProgrammingError(
            "Handshake misses authentication info ({0})".format(exc))

    username = username or b''
    try:
        username_bytes = username.encode('utf8')  # pylint: disable=E1103
    except AttributeError:
        # Bytes have no encode(): use the value as-is.
        username_bytes = username

    header_format = '<IIH{filler}{usrlen}sx'.format(
        filler='x' * 22, usrlen=len(username_bytes))
    packet = struct.pack(header_format, client_flags, max_allowed_packet,
                         charset, username_bytes)
    packet += self._auth_response(client_flags, username, password,
                                  database, auth_plugin, auth_data,
                                  ssl_enabled)
    packet += self._connect_with_db(client_flags, database)

    if client_flags & ClientFlag.PLUGIN_AUTH:
        packet += auth_plugin.encode('utf8') + b'\x00'
    if client_flags & ClientFlag.CONNECT_ARGS and conn_attrs is not None:
        packet += self.make_conn_attrs(conn_attrs)
    return packet
|
||
|
|
||
|
def make_conn_attrs(self, conn_attrs):
    """Encode the connection attributes for the authentication packet.

    Each attribute is written as a length-encoded name followed by a
    length-encoded value (both UTF-8). The whole sequence is prefixed
    with its total byte length, also as a length-encoded integer.
    ``None`` values are sent as empty strings. ``conn_attrs`` itself is
    not modified.

    Args:
        conn_attrs (dict): Attribute name (str) -> value (str or None).

    Returns:
        bytes: The encoded attributes block.
    """
    def _lenenc_int(num):
        # MySQL length-encoded integer: one byte below 251, otherwise a
        # 0xfc/0xfd/0xfe marker followed by 2/3/8 little-endian bytes.
        if num < 251:
            return struct.pack('<B', num)
        if num < 2 ** 16:
            return b'\xfc' + struct.pack('<H', num)
        if num < 2 ** 24:
            return b'\xfd' + struct.pack('<I', num)[:3]
        return b'\xfe' + struct.pack('<Q', num)

    parts = []
    for attr_name in conn_attrs:
        attr_value = conn_attrs[attr_name]
        if attr_value is None:
            attr_value = ""
        # Lengths must count encoded bytes, not characters, or non-ASCII
        # names/values would corrupt the packet.
        name_bytes = attr_name.encode('utf8')
        value_bytes = attr_value.encode('utf8')
        parts.append(_lenenc_int(len(name_bytes)) + name_bytes)
        parts.append(_lenenc_int(len(value_bytes)) + value_bytes)

    payload = b''.join(parts)
    return _lenenc_int(len(payload)) + payload
|
||
|
|
||
|
def make_auth_ssl(self, charset=45, client_flags=0,
                  max_allowed_packet=1073741824):
    """Build the SSL authentication packet.

    Layout: client flags (utils.int4store), max packet size
    (utils.int4store), character set (utils.int2store), then 22 zero
    filler bytes.
    """
    packet = utils.int4store(client_flags)
    packet += utils.int4store(max_allowed_packet)
    packet += utils.int2store(charset)
    packet += b'\x00' * 22
    return packet
|
||
|
|
||
|
def make_command(self, command, argument=None):
    """Build a command payload: the one-byte ``command`` code, followed
    by ``argument`` when one is given."""
    packet = utils.int1store(command)
    return packet if argument is None else packet + argument
|
||
|
|
||
|
def make_stmt_fetch(self, statement_id, rows=1):
    """Build the Fetch Statement argument: the statement id followed by
    the number of rows to fetch, each stored with utils.int4store."""
    stmt_part = utils.int4store(statement_id)
    rows_part = utils.int4store(rows)
    return stmt_part + rows_part
|
||
|
|
||
|
def make_change_user(self, handshake, username=None, password=None,
                     database=None, charset=45, client_flags=0,
                     ssl_enabled=False, auth_plugin=None):
    """Build a packet for the Change User command.

    Layout: ServerCmd.CHANGE_USER byte, NUL-terminated user name,
    authentication response (see _auth_response), database part (see
    _connect_with_db), character set (2 bytes), and -- when
    ClientFlag.PLUGIN_AUTH is set -- the NUL-terminated plugin name.

    ``auth_plugin`` overrides the plugin named in ``handshake``.

    Raises:
        errors.ProgrammingError: if ``handshake`` lacks 'auth_data' or
            'auth_plugin', or cannot be indexed.
    """
    try:
        auth_data = handshake['auth_data']
        auth_plugin = auth_plugin or handshake['auth_plugin']
    except (TypeError, KeyError) as exc:
        raise errors.ProgrammingError(
            "Handshake misses authentication info ({0})".format(exc))

    username = username or b''
    try:
        user_bytes = username.encode('utf8')  # pylint: disable=E1103
    except AttributeError:
        # Bytes have no encode(): use the value as-is.
        user_bytes = username

    packet = struct.pack('<B{usrlen}sx'.format(usrlen=len(user_bytes)),
                         ServerCmd.CHANGE_USER, user_bytes)
    packet += self._auth_response(client_flags, username, password,
                                  database, auth_plugin, auth_data,
                                  ssl_enabled)
    packet += self._connect_with_db(client_flags, database)
    packet += struct.pack('<H', charset)

    if client_flags & ClientFlag.PLUGIN_AUTH:
        packet += auth_plugin.encode('utf8') + b'\x00'
    return packet
|
||
|
|
||
|
def parse_handshake(self, packet):
    """Parse the initial handshake packet sent by the server.

    Args:
        packet: The full packet, including the 4-byte packet header.

    Returns:
        dict with 'protocol', 'server_version_original' (str),
        'server_threadid', 'charset', 'server_status', 'capabilities',
        'auth_data' (both auth-data parts concatenated, the trailing NUL
        of the second part removed) and 'auth_plugin' (str;
        'mysql_native_password' when the server does not announce
        ClientFlag.PLUGIN_AUTH).

    Raises:
        DatabaseError: if the protocol version differs from
            PROTOCOL_VERSION.
    """
    res = {}
    res['protocol'] = struct_unpack('<xxxxB', packet[0:5])[0]
    if res["protocol"] != PROTOCOL_VERSION:
        raise DatabaseError("Protocol mismatch; server version = {}, "
                            "client version = {}".format(res["protocol"],
                                                         PROTOCOL_VERSION))
    (packet, res['server_version_original']) = utils.read_string(
        packet[5:], end=b'\x00')

    (res['server_threadid'],
     auth_data1,
     capabilities1,
     res['charset'],
     res['server_status'],
     capabilities2,
     auth_data_length
    ) = struct_unpack('<I8sx2sBH2sBxxxxxxxxxx', packet[0:31])
    res['server_version_original'] = res['server_version_original'].decode()

    packet = packet[31:]

    # The capability flags arrive as two separate 2-byte halves.
    capabilities = utils.intread(capabilities1 + capabilities2)
    auth_data2 = b''
    if capabilities & ClientFlag.SECURE_CONNECTION:
        # At most 13 bytes of the second part follow. Clamp at 0 so an
        # announced auth data length of 8 or less cannot produce a
        # negative slice bound.
        size = min(13, auth_data_length - 8) if auth_data_length else 13
        size = max(0, size)
        auth_data2 = packet[0:size]
        packet = packet[size:]
        # Guard the empty case: indexing [-1] would raise IndexError.
        if auth_data2 and auth_data2[-1] == 0:
            auth_data2 = auth_data2[:-1]

    if capabilities & ClientFlag.PLUGIN_AUTH:
        if (b'\x00' not in packet
                and res['server_version_original'].startswith("5.5.8")):
            # MySQL server 5.5.8 has a bug where the end byte is not sent
            (packet, res['auth_plugin']) = (b'', packet)
        else:
            (packet, res['auth_plugin']) = utils.read_string(
                packet, end=b'\x00')
        res['auth_plugin'] = res['auth_plugin'].decode('utf-8')
    else:
        res['auth_plugin'] = 'mysql_native_password'

    res['auth_data'] = auth_data1 + auth_data2
    res['capabilities'] = capabilities
    return res
|
||
|
|
||
|
def parse_ok(self, packet):
    """Parse an OK packet.

    Args:
        packet: The full packet, including the 4-byte packet header; the
            byte at offset 4 must be 0.

    Returns:
        dict with 'field_count', 'affected_rows', 'insert_id',
        'status_flag', 'warning_count' and, when the packet has trailing
        data, 'info_msg' (str).

    Raises:
        errors.InterfaceError: if the packet is not an OK packet or its
            fields cannot be decoded (e.g. a truncated packet).
    """
    if packet[4] != 0:
        raise errors.InterfaceError("Failed parsing OK packet (invalid).")

    ok_packet = {}
    try:
        ok_packet['field_count'] = struct_unpack('<xxxxB', packet[0:5])[0]
        (packet, ok_packet['affected_rows']) = utils.read_lc_int(packet[5:])
        (packet, ok_packet['insert_id']) = utils.read_lc_int(packet)
        (ok_packet['status_flag'],
         ok_packet['warning_count']) = struct_unpack('<HH', packet[0:4])
        packet = packet[4:]
        if packet:
            (packet, ok_packet['info_msg']) = utils.read_lc_string(packet)
            ok_packet['info_msg'] = ok_packet['info_msg'].decode('utf-8')
    except (struct.error, ValueError):
        # struct.error is not a ValueError subclass; without catching it,
        # a truncated packet would escape as a raw struct.error.
        raise errors.InterfaceError("Failed parsing OK packet.")

    return ok_packet
|
||
|
|
||
|
def parse_column_count(self, packet):
    """Return the number of columns announced for a result set.

    The count is the length-encoded integer that follows the 4-byte
    packet header.

    Raises:
        errors.InterfaceError: if the integer cannot be decoded.
    """
    payload = packet[4:]
    try:
        column_count = utils.read_lc_int(payload)[1]
    except (struct.error, ValueError):
        raise errors.InterfaceError("Failed parsing column count")
    return column_count
|
||
|
|
||
|
def parse_column(self, packet, charset='utf-8'):
    """Parse a column-definition packet into a DB-API description tuple.

    After the 4-byte header come six length-encoded strings (catalog,
    db, table, org_table, name, org_name); only ``name`` is kept. The
    fixed-size remainder supplies the field type and the flags.

    Returns:
        tuple: (name, field_type, None, None, None, None, null_ok, flags)
        where ``name`` is decoded with ``charset`` and ``null_ok`` is
        non-zero when FieldFlag.NOT_NULL is not set.

    Raises:
        errors.InterfaceError: if the fixed-size remainder does not match
            the expected layout.
    """
    rest = packet[4:]
    strings = []
    for _ in range(6):  # catalog, db, table, org_table, name, org_name
        (rest, text) = utils.read_lc_string(rest)
        strings.append(text)
    name = strings[4]

    try:
        fixed = struct_unpack('<xHIBHBxx', rest)
    except struct.error:
        raise errors.InterfaceError("Failed parsing column information")
    field_type = fixed[2]
    flags = fixed[3]

    return (
        name.decode(charset),
        field_type,
        None,  # display_size
        None,  # internal_size
        None,  # precision
        None,  # scale
        ~flags & FieldFlag.NOT_NULL,  # null_ok
        flags,  # MySQL specific
    )
|
||
|
|
||
|
def parse_eof(self, packet):
    """Parse an EOF packet.

    A packet whose byte at offset 4 is 0 is an OK packet used in place
    of EOF and is handed to parse_ok(). Otherwise the packet must be
    exactly 9 bytes with the 0xfe (254) marker at offset 4.

    Returns:
        dict with 'warning_count' and 'status_flag' (or parse_ok()'s
        result).

    Raises:
        errors.InterfaceError: if the packet is not a valid EOF packet.
    """
    if packet[4] == 0:
        # EOF packet deprecation
        return self.parse_ok(packet)

    err_msg = "Failed parsing EOF packet."
    try:
        (_, marker, warnings, status) = struct_unpack('<xxxBBHH', packet)
    except struct.error:
        raise errors.InterfaceError(err_msg)

    if marker != 254 or len(packet) > 9:
        raise errors.InterfaceError(err_msg)

    return {'warning_count': warnings, 'status_flag': status}
|
||
|
|
||
|
def parse_statistics(self, packet, with_header=True):
    """Parse the reply to a COM_STATISTICS request.

    The reply is a sequence of ``label: value`` pairs separated by two
    spaces. Values that parse as integers are returned as int; all
    others as Decimal.

    Args:
        packet: The reply; it includes the 4-byte packet header unless
            ``with_header`` is False.
        with_header (bool): Whether to skip the first 4 bytes.

    Returns:
        dict: label (str) -> int or Decimal.

    Raises:
        errors.InterfaceError: if a pair or a value cannot be parsed.
    """
    errmsg = "Failed getting COM_STATISTICS information"
    res = {}
    payload = packet[4:] if with_header else packet
    # Information is separated by 2 spaces
    for pair in payload.split(b'\x20\x20'):
        try:
            (lbl, val) = [v.strip() for v in pair.split(b':', 2)]
        except ValueError:
            # The pair did not split into exactly a label and a value.
            raise errors.InterfaceError(errmsg)

        # It's either an integer or a decimal
        lbl = lbl.decode('utf-8')
        try:
            res[lbl] = int(val)
        except (TypeError, ValueError):
            # TypeError: Python 2 int() rejects bytearray.
            try:
                res[lbl] = Decimal(val.decode('utf-8'))
            except (ArithmeticError, TypeError, ValueError):
                # decimal.InvalidOperation derives from ArithmeticError;
                # UnicodeDecodeError derives from ValueError.
                raise errors.InterfaceError(
                    "{0} ({1}:{2}).".format(errmsg, lbl, val))
    return res
|
||
|
|
||
|
def read_text_result(self, sock, version, count=1):
    """Read rows of a text-protocol result set from ``sock``.

    Reads until an EOF packet has been parsed or ``count`` iterations
    have run; a row split over several packets counts once. Passing
    ``count=None`` disables the limit (``i == None`` is never true), so
    reading continues until EOF.

    Args:
        sock: Object whose ``recv()`` returns the next packet; the
            payload starts at offset 4 (after the packet header).
        version: Not used by this method.
        count: Maximum number of rows/packets to read (see above).

    Returns:
        tuple: ``(rows, eof)`` -- a list with one entry per row (as
        returned by ``utils.read_lc_string_list``) and the dict returned
        by ``parse_eof()``, or ``None`` if no EOF packet was read.

    Raises:
        The exception returned by ``get_exception(packet)`` when a
        non-EOF packet does not decode as a row.
    """
    rows = []
    eof = None
    rowdata = None
    i = 0
    while True:
        if eof or i == count:
            break
        packet = sock.recv()
        if packet.startswith(b'\xff\xff\xff'):
            # Header announces the maximum payload length: the row
            # continues in the following packet(s). Collect fragment
            # payloads until a shorter packet ends the row, then decode
            # the joined data as one row.
            datas = [packet[4:]]
            packet = sock.recv()
            while packet.startswith(b'\xff\xff\xff'):
                datas.append(packet[4:])
                packet = sock.recv()
            datas.append(packet[4:])
            rowdata = utils.read_lc_string_list(bytearray(b'').join(datas))
        elif packet[4] == 254 and packet[0] < 7:
            # 0xfe marker in a short packet (first length byte < 7): EOF.
            eof = self.parse_eof(packet)
            rowdata = None
        else:
            eof = None
            rowdata = utils.read_lc_string_list(packet[4:])
        if eof is None and rowdata is not None:
            rows.append(rowdata)
        elif eof is None and rowdata is None:
            # Neither a row nor EOF (read_lc_string_list returned None).
            raise get_exception(packet)
        i += 1
    return rows, eof
|
||
|
|
||
|
def _parse_binary_integer(self, packet, field):
    """Read one binary-protocol integer for the column described by
    ``field``.

    The width follows the column type in ``field[1]``: TINY 1, SHORT 2,
    INT24/LONG 4, LONGLONG 8 bytes. The value is read unsigned when
    ``field[7]`` (the column flags) has FieldFlag.UNSIGNED.

    Returns:
        tuple: (remaining packet, integer value)
    """
    column_type = field[1]
    if column_type == FieldType.TINY:
        fmt, size = '<b', 1
    elif column_type == FieldType.SHORT:
        fmt, size = '<h', 2
    elif column_type in (FieldType.INT24, FieldType.LONG):
        fmt, size = '<i', 4
    elif column_type == FieldType.LONGLONG:
        fmt, size = '<q', 8

    if field[7] & FieldFlag.UNSIGNED:
        fmt = fmt.upper()

    rest = packet[size:]
    number = struct_unpack(fmt, packet[0:size])[0]
    return (rest, number)
|
||
|
|
||
|
def _parse_binary_float(self, packet, field):
    """Read a binary-protocol floating point value.

    DOUBLE columns are read as 8-byte doubles; any other column type is
    read as a 4-byte float.

    Returns:
        tuple: (remaining packet, float value)
    """
    is_double = field[1] == FieldType.DOUBLE
    fmt, size = ('<d', 8) if is_double else ('<f', 4)
    return (packet[size:], struct_unpack(fmt, packet[0:size])[0])
|
||
|
|
||
|
def _parse_binary_timestamp(self, packet, field):
    """Read a binary-protocol DATE/DATETIME/TIMESTAMP value.

    The first byte gives the number of bytes that follow. A length of 4
    yields a datetime.date; 7 or more yields a datetime.datetime, where
    a length of 11 also carries microseconds. Any other length (e.g. 0)
    yields None.

    Returns:
        tuple: (remaining packet, value)
    """
    size = packet[0]
    if size == 4:
        result = datetime.date(struct_unpack('<H', packet[1:3])[0],
                               packet[3], packet[4])
    elif size >= 7:
        microsecond = (struct_unpack('<I', packet[8:size + 1])[0]
                       if size == 11 else 0)
        result = datetime.datetime(struct_unpack('<H', packet[1:3])[0],
                                   packet[3], packet[4],
                                   packet[5], packet[6], packet[7],
                                   microsecond)
    else:
        result = None
    return (packet[size + 1:], result)
|
||
|
|
||
|
def _parse_binary_time(self, packet, field):
    """Read a binary-protocol TIME value as a datetime.timedelta.

    The first byte gives the number of bytes that follow (0, 8 or 12 in
    the MySQL binary protocol): sign (1 = negative), days (4 bytes),
    hours, minutes, seconds (1 byte each) and optionally microseconds
    (4 bytes). A length of 0 encodes a zero time.

    Returns:
        tuple: (remaining packet, datetime.timedelta)
    """
    length = packet[0]
    if not length:
        # A zero time is sent without any payload bytes.
        return (packet[1:], datetime.timedelta())

    data = packet[1:length + 1]
    mcs = 0
    if length > 8:
        mcs = struct_unpack('<I', data[8:])[0]
    days = struct_unpack('<I', data[1:5])[0]
    value = datetime.timedelta(days=days,
                               hours=data[5],
                               minutes=data[6],
                               seconds=data[7],
                               microseconds=mcs)
    # The sign applies to the whole value, not just the day count;
    # otherwise e.g. -01:00:00 (days=0) would come back positive.
    if data[0] == 1:
        value = -value

    return (packet[length + 1:], value)
|
||
|
|
||
|
def _parse_binary_values(self, fields, packet, charset='utf-8'):
    """Decode one binary-protocol row.

    ``packet`` starts with the NULL bitmap, whose bits are offset by 2
    positions, followed by the values of the non-NULL columns in order.

    Args:
        fields: Column descriptions as returned by parse_column().
        packet: Row payload (the bytes after the row's marker byte).
        charset: Codec used to decode values of columns that are not
            integer, float, date/time or TIME columns.

    Returns:
        tuple: One value per column, None for NULL columns.
    """
    bitmap_size = (len(fields) + 7 + 2) // 8
    null_bitmap = [int(byte) for byte in packet[0:bitmap_size]]
    rest = packet[bitmap_size:]

    row = []
    for pos, field in enumerate(fields):
        bit = pos + 2
        if null_bitmap[bit // 8] & (1 << bit % 8):
            row.append(None)
            continue

        column_type = field[1]
        if column_type in (FieldType.TINY, FieldType.SHORT,
                           FieldType.INT24,
                           FieldType.LONG, FieldType.LONGLONG):
            (rest, value) = self._parse_binary_integer(rest, field)
        elif column_type in (FieldType.DOUBLE, FieldType.FLOAT):
            (rest, value) = self._parse_binary_float(rest, field)
        elif column_type in (FieldType.DATETIME, FieldType.DATE,
                             FieldType.TIMESTAMP):
            (rest, value) = self._parse_binary_timestamp(rest, field)
        elif column_type == FieldType.TIME:
            (rest, value) = self._parse_binary_time(rest, field)
        else:
            (rest, raw) = utils.read_lc_string(rest)
            value = raw.decode(charset)
        row.append(value)

    return tuple(row)
|
||
|
|
||
|
def read_binary_result(self, sock, columns, count=1, charset='utf-8'):
    """Read rows of a binary-protocol result set from ``sock``.

    Reads until an EOF packet (0xfe marker at offset 4) is parsed or
    ``count`` packets have been read (``count=None`` reads until EOF).
    Row packets (0x00 marker) are decoded with _parse_binary_values().

    Args:
        sock: Object whose recv() returns the next packet, including the
            4-byte packet header.
        columns: Column descriptions as returned by parse_column().
        count: Maximum number of packets to read.
        charset: Codec for string values.

    Returns:
        tuple: (rows, eof) -- the list of row tuples and the dict from
        parse_eof(), or None if no EOF packet was read.

    Raises:
        The exception returned by get_exception(packet) for any packet
        that is neither a row nor an EOF packet (e.g. an error packet).
    """
    rows = []
    eof = None
    i = 0
    while eof is None and i != count:
        packet = sock.recv()
        marker = packet[4]
        if marker == 254:
            eof = self.parse_eof(packet)
        elif marker == 0:
            rows.append(
                self._parse_binary_values(columns, packet[5:], charset))
        else:
            # Any other packet (e.g. an error) must raise here, so that
            # the previously decoded row is never appended a second time.
            raise get_exception(packet)
        i += 1
    return (rows, eof)
|
||
|
|
||
|
def parse_binary_prepare_ok(self, packet):
    """Parse the binary-protocol OK packet returned when preparing a
    statement.

    Returns:
        dict with 'statement_id' (4 bytes), 'num_columns' (2),
        'num_params' (2) and, after one filler byte, 'warning_count' (2).

    Raises:
        errors.InterfaceError: if the byte at offset 4 is not 0 or the
            fields cannot be read.
    """
    if packet[4] != 0:
        raise errors.InterfaceError("Failed parsing Binary OK packet")

    result = {}
    try:
        rest = packet[5:]
        for key, size in (('statement_id', 4),
                          ('num_columns', 2),
                          ('num_params', 2)):
            (rest, result[key]) = utils.read_int(rest, size)
        rest = rest[1:]  # Filler 1 * \x00
        (rest, result['warning_count']) = utils.read_int(rest, 2)
    except ValueError:
        raise errors.InterfaceError("Failed parsing Binary OK packet")

    return result
|
||
|
|
||
|
def _prepare_binary_integer(self, value):
    """Pack an int for the binary protocol using the smallest type that
    fits.

    Negative values use the signed TINY/SHORT/LONG/LONGLONG formats with
    flags 0. Non-negative values use the unsigned formats with flags 128
    (0x80), which make_stmt_execute writes after the field type.

    Returns:
        tuple: (packed bytes, field type, flags)
    """
    if value < 0:
        flags = 0
        if value >= -128:
            fmt, field_type = '<b', FieldType.TINY
        elif value >= -32768:
            fmt, field_type = '<h', FieldType.SHORT
        elif value >= -2147483648:
            fmt, field_type = '<i', FieldType.LONG
        else:
            fmt, field_type = '<q', FieldType.LONGLONG
    else:
        flags = 128
        if value <= 255:
            fmt, field_type = '<B', FieldType.TINY
        elif value <= 65535:
            fmt, field_type = '<H', FieldType.SHORT
        elif value <= 4294967295:
            fmt, field_type = '<I', FieldType.LONG
        else:
            fmt, field_type = '<Q', FieldType.LONGLONG
    return (struct.pack(fmt, value), field_type, flags)
|
||
|
|
||
|
def _prepare_binary_timestamp(self, value):
    """Pack a datetime.datetime or datetime.date for the binary protocol.

    The result is a length byte followed by year (2 bytes), month and
    day. Datetimes add hour, minute and second, plus 4 bytes of
    microseconds when the microsecond is non-zero.

    Returns:
        tuple: (packed value, FieldType.DATETIME or FieldType.DATE)

    Raises:
        ValueError: if ``value`` is neither a datetime nor a date.
    """
    if isinstance(value, datetime.datetime):
        field_type = FieldType.DATETIME
        with_time = True
    elif isinstance(value, datetime.date):
        field_type = FieldType.DATE
        with_time = False
    else:
        raise ValueError(
            "Argument must a datetime.datetime or datetime.date")

    body = utils.int2store(value.year)
    body += utils.int1store(value.month)
    body += utils.int1store(value.day)
    if with_time:
        for part in (value.hour, value.minute, value.second):
            body += utils.int1store(part)
        if value.microsecond > 0:
            body += utils.int4store(value.microsecond)

    return (utils.int1store(len(body)) + body, field_type)
|
||
|
|
||
|
def _prepare_binary_time(self, value):
    """Pack a datetime.timedelta or datetime.time for the binary protocol.

    Layout: length byte, sign byte (1 = negative), days (4 bytes),
    hours, minutes, seconds (1 byte each) and, when non-zero,
    microseconds (4 bytes). For a negative timedelta the sign byte is
    set and the remaining fields hold the magnitude of the value.

    Returns:
        tuple: (packed value, FieldType.TIME)

    Raises:
        ValueError: if ``value`` is neither a timedelta nor a time.
    """
    if not isinstance(value, (datetime.timedelta, datetime.time)):
        raise ValueError(
            "Argument must a datetime.timedelta or datetime.time")

    field_type = FieldType.TIME
    negative = 0

    if isinstance(value, datetime.timedelta):
        if value.days < 0:
            negative = 1
            # Python normalises negative timedeltas to a negative day
            # count plus positive seconds (-1h is days=-1,
            # seconds=82800). The wire format wants sign + magnitude.
            value = -value
        (hours, remainder) = divmod(value.seconds, 3600)
        (mins, secs) = divmod(remainder, 60)
        packed = (utils.int4store(value.days) +
                  utils.int1store(hours) +
                  utils.int1store(mins) +
                  utils.int1store(secs))
        mcs = value.microseconds
    else:
        packed = (utils.int4store(0) +
                  utils.int1store(value.hour) +
                  utils.int1store(value.minute) +
                  utils.int1store(value.second))
        mcs = value.microsecond
    if mcs:
        packed += utils.int4store(mcs)

    packed = utils.int1store(negative) + packed
    packed = utils.int1store(len(packed)) + packed

    return (packed, field_type)
|
||
|
|
||
|
def _prepare_stmt_send_long_data(self, statement, param, data):
    """Build the payload for sending long data to a prepared statement.

    Layout: statement id (utils.int4store), parameter index
    (utils.int2store), then ``data`` unchanged.

    Returns the concatenated payload.
    """
    header = utils.int4store(statement) + utils.int2store(param)
    return header + data
|
||
|
|
||
|
def make_stmt_execute(self, statement_id, data=(), parameters=(),
                      flags=0, long_data_used=None, charset='utf8'):
    """Build the payload of a Statement Execute command.

    Layout: statement id (4 bytes), ``flags`` (1 byte), iteration count
    (4 bytes, always 1), NULL bitmap, a constant 1 byte
    (new-params-bound flag), one type/flag byte pair per parameter, and
    finally the packed values of the non-NULL parameters that were not
    sent as long data.

    Args:
        statement_id: Id of the prepared statement.
        data: Parameter values, one per entry of ``parameters``.
        parameters: Parameter descriptions; only their number is used.
        flags: Flags byte written right after the statement id.
        long_data_used: Mapping of parameter position to a tuple whose
            first item is true for binary data. Those values were sent
            separately, so only their type (BLOB or STRING) is written.
        charset: Codec used to encode str/unicode/Decimal values;
            'utf8mb4' is mapped to 'utf8'.

    Raises:
        errors.InterfaceError: if ``data`` and ``parameters`` differ in
            length.
        errors.ProgrammingError: for values of unsupported types.
    """
    iteration_count = 1
    null_bitmap = [0] * ((len(data) + 7) // 8)
    values = []
    types = []
    packed = b''
    if charset == 'utf8mb4':
        charset = 'utf8'
    if long_data_used is None:
        long_data_used = {}
    if parameters and data:
        if len(data) != len(parameters):
            raise errors.InterfaceError(
                "Failed executing prepared statement: data values does not"
                " match number of parameters")
        for pos, _ in enumerate(parameters):
            value = data[pos]
            # Per-parameter flags, kept separate from the ``flags``
            # argument so the header byte is not overwritten by the last
            # parameter's flags (e.g. the unsigned marker for ints).
            param_flags = 0
            if value is None:
                null_bitmap[(pos // 8)] |= 1 << (pos % 8)
                types.append(utils.int1store(FieldType.NULL) +
                             utils.int1store(param_flags))
                continue
            elif pos in long_data_used:
                if long_data_used[pos][0]:
                    # We suppose binary data
                    field_type = FieldType.BLOB
                else:
                    # We suppose text data
                    field_type = FieldType.STRING
            elif isinstance(value, int):
                (packed, field_type,
                 param_flags) = self._prepare_binary_integer(value)
                values.append(packed)
            elif isinstance(value, str):
                if PY2:
                    values.append(utils.lc_int(len(value)) +
                                  value)
                else:
                    value = value.encode(charset)
                    values.append(
                        utils.lc_int(len(value)) + value)
                field_type = FieldType.VARCHAR
            elif isinstance(value, bytes):
                values.append(utils.lc_int(len(value)) + value)
                field_type = FieldType.BLOB
            elif PY2 and \
                    isinstance(value, unicode):  # pylint: disable=E0602
                value = value.encode(charset)
                values.append(utils.lc_int(len(value)) + value)
                field_type = FieldType.VARCHAR
            elif isinstance(value, Decimal):
                text = str(value).encode(charset)
                values.append(utils.lc_int(len(text)) + text)
                field_type = FieldType.DECIMAL
            elif isinstance(value, float):
                values.append(struct.pack('<d', value))
                field_type = FieldType.DOUBLE
            elif isinstance(value, (datetime.datetime, datetime.date)):
                (packed, field_type) = self._prepare_binary_timestamp(
                    value)
                values.append(packed)
            elif isinstance(value, (datetime.timedelta, datetime.time)):
                (packed, field_type) = self._prepare_binary_time(value)
                values.append(packed)
            else:
                raise errors.ProgrammingError(
                    "MySQL binary protocol can not handle "
                    "'{classname}' objects".format(
                        classname=value.__class__.__name__))
            types.append(utils.int1store(field_type) +
                         utils.int1store(param_flags))

    packet = (
        utils.int4store(statement_id) +
        utils.int1store(flags) +
        utils.int4store(iteration_count) +
        b''.join([struct.pack('B', bit) for bit in null_bitmap]) +
        utils.int1store(1)
    )

    for a_type in types:
        packet += a_type

    for a_value in values:
        packet += a_value

    return packet
|
||
|
|
||
|
def parse_auth_switch_request(self, packet):
    """Parse an AuthSwitchRequest packet (0xfe marker at offset 4).

    Returns:
        tuple: (plugin name as str, remaining plugin data with one
        trailing NUL byte removed if present)

    Raises:
        errors.InterfaceError: if the marker byte is not 0xfe (254).
    """
    if packet[4] != 254:
        raise errors.InterfaceError(
            "Failed parsing AuthSwitchRequest packet")

    (plugin_data, plugin_name) = utils.read_string(packet[5:], end=b'\x00')
    if plugin_data and plugin_data[-1] == 0:
        plugin_data = plugin_data[:-1]
    return plugin_name.decode('utf8'), plugin_data
|
||
|
|
||
|
def parse_auth_more_data(self, packet):
    """Return the payload of an AuthMoreData packet.

    The byte at offset 4 (right after the packet header) must be the
    0x01 marker; everything after it is returned.

    Raises:
        errors.InterfaceError: if the marker byte is not 0x01.
    """
    marker = packet[4]
    if marker != 1:
        raise errors.InterfaceError(
            "Failed parsing AuthMoreData packet")
    return packet[5:]
|