mirror of https://github.com/sgoudham/Enso-Bot.git
You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
347 lines
12 KiB
Python
347 lines
12 KiB
Python
import asyncio
|
|
import mimetypes
|
|
import os
|
|
import pathlib
|
|
from functools import partial
|
|
from typing import ( # noqa
|
|
IO,
|
|
TYPE_CHECKING,
|
|
Any,
|
|
Awaitable,
|
|
Callable,
|
|
List,
|
|
Optional,
|
|
Union,
|
|
cast,
|
|
)
|
|
|
|
from . import hdrs
|
|
from .abc import AbstractStreamWriter
|
|
from .base_protocol import BaseProtocol
|
|
from .helpers import set_exception, set_result
|
|
from .http_writer import StreamWriter
|
|
from .log import server_logger
|
|
from .typedefs import LooseHeaders
|
|
from .web_exceptions import (
|
|
HTTPNotModified,
|
|
HTTPPartialContent,
|
|
HTTPPreconditionFailed,
|
|
HTTPRequestRangeNotSatisfiable,
|
|
)
|
|
from .web_response import StreamResponse
|
|
|
|
# Public API of this module: only the FileResponse class is exported.
__all__ = ('FileResponse',)
|
|
|
|
if TYPE_CHECKING:  # pragma: no cover
    # Imported only for static type checking; annotations below refer to
    # it by the string name 'BaseRequest'.
    from .web_request import BaseRequest  # noqa


# Optional async callback that takes a chunk of bytes; forwarded to the
# StreamWriter constructor by SendfileStreamWriter.
_T_OnChunkSent = Optional[Callable[[bytes], Awaitable[None]]]


# True when the AIOHTTP_NOSENDFILE environment variable is set to any
# non-empty value. FileResponse then uses its chunked fallback even if
# os.sendfile is available.
NOSENDFILE = bool(os.environ.get("AIOHTTP_NOSENDFILE"))
|
|
|
|
|
|
class SendfileStreamWriter(StreamWriter):
    """Stream writer that sends a file body with ``os.sendfile``.

    Data written through the regular writer interface (see ``_write``) is
    collected in an in-memory list and flushed to the socket by
    ``sendfile()`` before the file contents are transferred.
    """

    def __init__(self,
                 protocol: BaseProtocol,
                 loop: asyncio.AbstractEventLoop,
                 fobj: IO[Any],
                 count: int,
                 on_chunk_sent: _T_OnChunkSent=None) -> None:
        """Remember the file to send and how many bytes of it to send.

        :param protocol: passed through to ``StreamWriter``.
        :param loop: passed through to ``StreamWriter``.
        :param fobj: open file object; its current position (``tell()``)
            is used as the starting offset and its ``fileno()`` as the
            source descriptor.
        :param count: number of bytes to transfer from ``fobj``.
        :param on_chunk_sent: passed through to ``StreamWriter``.
        """
        super().__init__(protocol, loop, on_chunk_sent)
        self._sendfile_buffer = []  # type: List[bytes]
        self._fobj = fobj
        self._count = count
        self._offset = fobj.tell()
        self._in_fd = fobj.fileno()

    def _write(self, chunk: bytes) -> None:
        """Buffer *chunk* instead of writing it to the transport."""
        # we overwrite StreamWriter._write, so nothing can be appended to
        # _buffer, and nothing is written to the transport directly by the
        # parent class
        self.output_size += len(chunk)
        self._sendfile_buffer.append(chunk)

    def _sendfile_cb(self, fut: 'asyncio.Future[None]', out_fd: int) -> None:
        """Writer callback: push more file data and settle *fut* when done.

        Registered with ``loop.add_writer`` by ``sendfile()``. The future
        is resolved once ``_do_sendfile`` reports completion, or receives
        the exception raised by it.
        """
        if fut.cancelled():
            return
        try:
            if self._do_sendfile(out_fd):
                set_result(fut, None)
        except Exception as exc:
            set_exception(fut, exc)

    def _do_sendfile(self, out_fd: int) -> bool:
        """Perform one ``os.sendfile`` call and update the bookkeeping.

        :param out_fd: descriptor of the destination socket.
        :return: ``True`` when no bytes remain to be sent.
        """
        try:
            n = os.sendfile(out_fd,
                            self._in_fd,
                            self._offset,
                            self._count)
            if n == 0:  # in_fd EOF reached
                # Nothing more can be read from the file, so account for
                # the whole remainder to terminate the transfer.
                n = self._count
        except (BlockingIOError, InterruptedError):
            # Nothing was sent this time; the caller tries again later.
            n = 0
        self.output_size += n
        self._offset += n
        self._count -= n
        assert self._count >= 0
        return self._count == 0

    def _done_fut(self, out_fd: int, fut: 'asyncio.Future[None]') -> None:
        """Done-callback of the transfer future: unregister the writer."""
        self.loop.remove_writer(out_fd)

    async def sendfile(self) -> None:
        """Flush the buffered data, then send the file contents.

        Works on a duplicate of the transport's socket, which is always
        closed afterwards. ``asyncio.CancelledError`` is re-raised; any
        other exception is logged at debug level and the transport is
        closed. Finally the parent class' ``write_eof()`` is awaited.
        """
        assert self.transport is not None
        out_socket = self.transport.get_extra_info('socket').dup()
        out_socket.setblocking(False)
        out_fd = out_socket.fileno()

        loop = self.loop
        data = b''.join(self._sendfile_buffer)
        try:
            await loop.sock_sendall(out_socket, data)
            if not self._do_sendfile(out_fd):
                # The first call did not send everything: continue from
                # the writer callback and wait until it settles the future.
                fut = loop.create_future()
                fut.add_done_callback(partial(self._done_fut, out_fd))
                loop.add_writer(out_fd, self._sendfile_cb, fut, out_fd)
                await fut
        except asyncio.CancelledError:
            raise
        except Exception:
            server_logger.debug('Socket error')
            self.transport.close()
        finally:
            out_socket.close()

        await super().write_eof()

    async def write_eof(self, chunk: bytes=b'') -> None:
        """Do nothing; ``sendfile()`` calls the parent's ``write_eof()``.

        NOTE(review): presumably a no-op so that a later ``write_eof()``
        issued by the response does not finish the stream twice -- confirm
        against the caller.
        """
        pass
|
|
|
|
|
|
class FileResponse(StreamResponse):
    """A response object can be used to send files."""

    def __init__(self, path: Union[str, pathlib.Path],
                 chunk_size: int=256*1024,
                 status: int=200,
                 reason: Optional[str]=None,
                 headers: Optional[LooseHeaders]=None) -> None:
        """Create a response that serves the file at *path*.

        :param path: file to send; a ``str`` is converted to
            ``pathlib.Path``.
        :param chunk_size: size of the reads used by the non-sendfile
            fallback.
        :param status: passed through to ``StreamResponse``.
        :param reason: passed through to ``StreamResponse``.
        :param headers: passed through to ``StreamResponse``.
        """
        super().__init__(status=status, reason=reason, headers=headers)

        if isinstance(path, str):
            path = pathlib.Path(path)

        self._path = path
        self._chunk_size = chunk_size

    async def _sendfile_system(self, request: 'BaseRequest',
                               fobj: IO[Any],
                               count: int) -> AbstractStreamWriter:
        # Write count bytes of fobj to resp using
        # the os.sendfile system call.
        #
        # For details check
        # https://github.com/KeepSafe/aiohttp/issues/1177
        # See https://github.com/KeepSafe/aiohttp/issues/958 for details
        #
        # request should be an aiohttp.web.Request instance.
        # fobj should be an open file object.
        # count should be an integer > 0.

        transport = request.transport
        assert transport is not None
        # Use the chunked fallback when the transport reports an SSL
        # context, exposes no socket, or response compression is enabled.
        if (transport.get_extra_info("sslcontext") or
                transport.get_extra_info("socket") is None or
                self.compression):
            writer = await self._sendfile_fallback(request, fobj, count)
        else:
            writer = SendfileStreamWriter(
                request.protocol,
                request._loop,
                fobj,
                count
            )
            # Install the writer on the request before preparing the
            # response, then transfer the file.
            request._payload_writer = writer

            await super().prepare(request)
            await writer.sendfile()

        return writer

    async def _sendfile_fallback(self, request: 'BaseRequest',
                                 fobj: IO[Any],
                                 count: int) -> AbstractStreamWriter:
        # Mimic the _sendfile_system() method, but without using the
        # os.sendfile() system call. This should be used on systems
        # that don't support the os.sendfile().

        # To keep memory usage low, fobj is transferred in chunks
        # controlled by the constructor's chunk_size argument.

        writer = await super().prepare(request)
        assert writer is not None

        chunk_size = self._chunk_size
        loop = asyncio.get_event_loop()

        remaining = count
        while remaining > 0:
            # Never ask for more than is still owed, so a small range
            # request does not read a whole chunk_size block.
            chunk = await loop.run_in_executor(
                None, fobj.read, min(chunk_size, remaining)
            )
            if not chunk:
                # End of file (or nothing readable) before count bytes.
                break
            await writer.write(chunk)
            # Account for the bytes actually read; a read may return
            # fewer bytes than requested.
            remaining -= len(chunk)

        await writer.drain()
        return writer

    # Pick the transfer strategy once, at class-definition time.
    if hasattr(os, "sendfile") and not NOSENDFILE:  # pragma: no cover
        _sendfile = _sendfile_system
    else:  # pragma: no cover
        _sendfile = _sendfile_fallback

    async def prepare(
            self,
            request: 'BaseRequest'
    ) -> Optional[AbstractStreamWriter]:
        """Evaluate conditional/range headers and send the file.

        Sets 304, 412 or 416 and returns without a file body when the
        If-Modified-Since, If-Unmodified-Since or Range checks call for
        it. Otherwise sets the content headers, opens the file, and
        transfers ``count`` bytes through ``self._sendfile``. The file is
        always closed afterwards.

        :param request: the request being answered.
        :return: whatever the parent ``prepare()`` or ``_sendfile``
            returns.
        """
        filepath = self._path

        gzip = False
        if 'gzip' in request.headers.get(hdrs.ACCEPT_ENCODING, ''):
            # Prefer a pre-compressed "<name>.gz" sibling when it exists.
            gzip_path = filepath.with_name(filepath.name + '.gz')

            if gzip_path.is_file():
                filepath = gzip_path
                gzip = True

        loop = asyncio.get_event_loop()
        st = await loop.run_in_executor(None, filepath.stat)

        modsince = request.if_modified_since
        if modsince is not None and st.st_mtime <= modsince.timestamp():
            self.set_status(HTTPNotModified.status_code)
            self._length_check = False
            # Delete any Content-Length headers provided by user. HTTP 304
            # should always have empty response body
            return await super().prepare(request)

        unmodsince = request.if_unmodified_since
        if unmodsince is not None and st.st_mtime > unmodsince.timestamp():
            self.set_status(HTTPPreconditionFailed.status_code)
            return await super().prepare(request)

        if hdrs.CONTENT_TYPE not in self.headers:
            ct, encoding = mimetypes.guess_type(str(filepath))
            if not ct:
                ct = 'application/octet-stream'
            should_set_ct = True
        else:
            encoding = 'gzip' if gzip else None
            should_set_ct = False

        status = self._status
        file_size = st.st_size
        count = file_size

        start = None

        ifrange = request.if_range
        if ifrange is None or st.st_mtime <= ifrange.timestamp():
            # If-Range header check:
            # condition = cached date >= last modification date
            # return 206 if True else 200.
            # if False:
            #   Range header would not be processed, return 200
            # if True but Range header missing
            #   return 200
            try:
                rng = request.http_range
                start = rng.start
                end = rng.stop
            except ValueError:
                # https://tools.ietf.org/html/rfc7233:
                # A server generating a 416 (Range Not Satisfiable) response to
                # a byte-range request SHOULD send a Content-Range header field
                # with an unsatisfied-range value.
                # The complete-length in a 416 response indicates the current
                # length of the selected representation.
                #
                # Will do the same below. Many servers ignore this and do not
                # send a Content-Range header with HTTP 416
                self.headers[hdrs.CONTENT_RANGE] = 'bytes */{0}'.format(
                    file_size)
                self.set_status(HTTPRequestRangeNotSatisfiable.status_code)
                return await super().prepare(request)

            # If a range request has been made, convert start, end slice
            # notation into file pointer offset and count
            if start is not None or end is not None:
                if start < 0 and end is None:  # return tail of file
                    start += file_size
                    if start < 0:
                        # if Range:bytes=-1000 in request header but file size
                        # is only 200, there would be trouble without this
                        start = 0
                    count = file_size - start
                else:
                    # rfc7233:If the last-byte-pos value is
                    # absent, or if the value is greater than or equal to
                    # the current length of the representation data,
                    # the byte range is interpreted as the remainder
                    # of the representation (i.e., the server replaces the
                    # value of last-byte-pos with a value that is one less than
                    # the current length of the selected representation).
                    count = min(end if end is not None else file_size,
                                file_size) - start

                if start >= file_size:
                    # HTTP 416 should be returned in this case.
                    #
                    # According to https://tools.ietf.org/html/rfc7233:
                    # If a valid byte-range-set includes at least one
                    # byte-range-spec with a first-byte-pos that is less than
                    # the current length of the representation, or at least one
                    # suffix-byte-range-spec with a non-zero suffix-length,
                    # then the byte-range-set is satisfiable. Otherwise, the
                    # byte-range-set is unsatisfiable.
                    self.headers[hdrs.CONTENT_RANGE] = 'bytes */{0}'.format(
                        file_size)
                    self.set_status(HTTPRequestRangeNotSatisfiable.status_code)
                    return await super().prepare(request)

                status = HTTPPartialContent.status_code
                # Even though you are sending the whole file, you should still
                # return a HTTP 206 for a Range request.
                self.set_status(status)

        if should_set_ct:
            self.content_type = ct  # type: ignore
        if encoding:
            self.headers[hdrs.CONTENT_ENCODING] = encoding
        if gzip:
            self.headers[hdrs.VARY] = hdrs.ACCEPT_ENCODING
        self.last_modified = st.st_mtime  # type: ignore
        self.content_length = count

        self.headers[hdrs.ACCEPT_RANGES] = 'bytes'

        real_start = cast(int, start)

        if status == HTTPPartialContent.status_code:
            self.headers[hdrs.CONTENT_RANGE] = 'bytes {0}-{1}/{2}'.format(
                real_start, real_start + count - 1, file_size)

        fobj = await loop.run_in_executor(None, filepath.open, 'rb')
        try:
            # Seek inside the try block so the file is closed even if the
            # seek fails or the task is cancelled while waiting for it.
            if start:  # be aware that start could be None or int=0 here.
                await loop.run_in_executor(None, fobj.seek, start)

            return await self._sendfile(request, fobj, count)
        finally:
            await loop.run_in_executor(None, fobj.close)
|