Added more requirements

pull/1/head
sgoudham 5 years ago
parent a517aeeaf6
commit 73f38a2347

@@ -58,7 +58,7 @@ class Waifus(commands.Cog):
    async def toga(self, ctx):
        try:
            # path = PureWindowsPath(r'C:\Users\sgoud\PycharmProjects\EnsoBot\images\togaImages.txt')
            with open('images/togaImages.txt') as file:
                toga_array = file.readlines()
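Aside: the hunk is truncated here. A hypothetical sketch (not part of this commit) of how such a command typically continues, picking a random URL from ``toga_array`` and posting it as a Discord embed via discord.py:

.. code-block:: python

    import random
    import discord

    # Hypothetical continuation of the try block above: choose one image
    # URL from the loaded lines and send it back as an embed.
    url = random.choice(toga_array).strip()
    embed = discord.Embed(title="Toga", colour=discord.Colour.purple())
    embed.set_image(url=url)
    await ctx.send(embed=embed)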

@@ -0,0 +1,354 @@
Metadata-Version: 1.2
Name: aiohttp
Version: 1.0.5
Summary: http client/server for asyncio
Home-page: https://github.com/KeepSafe/aiohttp/
Author: Nikolay Kim
Author-email: fafhrd91@gmail.com
Maintainer: Andrew Svetlov
Maintainer-email: andrew.svetlov@gmail.com
License: Apache 2
Description: http client/server for asyncio
==============================
.. image:: https://raw.github.com/KeepSafe/aiohttp/master/docs/_static/aiohttp-icon-128x128.png
   :height: 64px
   :width: 64px
   :alt: aiohttp logo

.. image:: https://travis-ci.org/KeepSafe/aiohttp.svg?branch=master
   :target: https://travis-ci.org/KeepSafe/aiohttp
   :align: right

.. image:: https://codecov.io/gh/KeepSafe/aiohttp/branch/master/graph/badge.svg
   :target: https://codecov.io/gh/KeepSafe/aiohttp

.. image:: https://badge.fury.io/py/aiohttp.svg
   :target: https://badge.fury.io/py/aiohttp
Features
--------
- Supports both client and server side of HTTP protocol.
- Supports both client and server Web-Sockets out-of-the-box.
- Web-server has middlewares and pluggable routing.
Getting started
---------------
Client
^^^^^^
To retrieve something from the web:
.. code-block:: python

    import aiohttp
    import asyncio

    async def fetch(session, url):
        with aiohttp.Timeout(10, loop=session.loop):
            async with session.get(url) as response:
                return await response.text()

    async def main(loop):
        async with aiohttp.ClientSession(loop=loop) as session:
            html = await fetch(session, 'http://python.org')
            print(html)

    if __name__ == '__main__':
        loop = asyncio.get_event_loop()
        loop.run_until_complete(main(loop))
Server
^^^^^^
This is a simple usage example:
.. code-block:: python

    from aiohttp import web

    async def handle(request):
        name = request.match_info.get('name', "Anonymous")
        text = "Hello, " + name
        return web.Response(text=text)

    async def wshandler(request):
        ws = web.WebSocketResponse()
        await ws.prepare(request)

        async for msg in ws:
            if msg.type == web.MsgType.text:
                ws.send_str("Hello, {}".format(msg.data))
            elif msg.type == web.MsgType.binary:
                ws.send_bytes(msg.data)
            elif msg.type == web.MsgType.close:
                break

        return ws

    app = web.Application()
    app.router.add_get('/echo', wshandler)
    app.router.add_get('/', handle)
    app.router.add_get('/{name}', handle)

    web.run_app(app)
Note: examples are written for Python 3.5+ and utilize PEP-492 aka
async/await. If you are using Python 3.4 please replace ``await`` with
``yield from`` and ``async def`` with ``@coroutine`` e.g.::

    async def coro(...):
        ret = await f()

should be replaced by::

    @asyncio.coroutine
    def coro(...):
        ret = yield from f()
Documentation
-------------
https://aiohttp.readthedocs.io/
Discussion list
---------------
*aio-libs* google group: https://groups.google.com/forum/#!forum/aio-libs
Requirements
------------
- Python >= 3.4.2
- chardet_
- multidict_
Optionally you may install the cChardet_ and aiodns_ libraries (highly
recommended for the sake of speed).
.. _chardet: https://pypi.python.org/pypi/chardet
.. _aiodns: https://pypi.python.org/pypi/aiodns
.. _multidict: https://pypi.python.org/pypi/multidict
.. _cChardet: https://pypi.python.org/pypi/cchardet
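A small sketch (not from the original README) showing how to check at runtime whether those optional speedups are importable; aiohttp picks them up automatically when present:

.. code-block:: python

    # Probe for the optional speedup libraries; both are plain imports,
    # installed with e.g. ``pip install cchardet aiodns``.
    def have_speedups():
        try:
            import cchardet  # noqa: F401 -- C-accelerated chardet
            import aiodns    # noqa: F401 -- asynchronous DNS resolver
        except ImportError:
            return False
        return True

    print("speedups available:", have_speedups())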
License
-------
``aiohttp`` is offered under the Apache 2 license.
Source code
------------
The latest developer version is available in a GitHub repository:
https://github.com/KeepSafe/aiohttp
Benchmarks
----------
If you are interested in efficiency, the AsyncIO community maintains a
list of benchmarks on the official wiki:
https://github.com/python/asyncio/wiki/Benchmarks
CHANGES
=======
1.0.5 (2016-10-11)
------------------
- Fix StreamReader._read_nowait to return all available
data up to the requested amount #1297
1.0.4 (2016-09-22)
------------------
- Fix FlowControlStreamReader.read_nowait so that it checks
whether the transport is paused #1206
1.0.2 (2016-09-22)
------------------
- Make CookieJar compatible with 32-bit systems #1188
- Add missing `WSMsgType` to `web_ws.__all__`, see #1200
- Fix `CookieJar` ctor when called with `loop=None` #1203
- Fix broken upper-casing in wsgi support #1197
1.0.1 (2016-09-16)
------------------
- Restore `aiohttp.web.MsgType` alias for `aiohttp.WSMsgType` for sake
of backward compatibility #1178
- Tune alabaster schema.
- Use `text/html` content type for displaying index pages by static
file handler.
- Fix `AssertionError` in static file handling #1177
- Fix access log formats `%O` and `%b` for static file handling
- Remove `debug` setting of GunicornWorker, use `app.debug`
to control its debug-mode instead
1.0.0 (2016-09-16)
-------------------
- Change default size for client session's connection pool from
unlimited to 20 #977
- Add IE support for cookie deletion. #994
- Remove deprecated `WebSocketResponse.wait_closed` method (BACKWARD
INCOMPATIBLE)
- Remove deprecated `force` parameter for `ClientResponse.close`
method (BACKWARD INCOMPATIBLE)
- Avoid using a mutable CIMultiDict kw param in make_mocked_request
#997
- Make WebSocketResponse.close a little bit faster by avoiding new
task creation just for timeout measurement
- Add `proxy` and `proxy_auth` params to `client.get()` and family,
deprecate `ProxyConnector` #998
- Add support for websocket send_json and receive_json, synchronize
server and client API for websockets #984
- Implement router shortcuts for most useful HTTP methods, use
`app.router.add_get()`, `app.router.add_post()` etc. instead of
`app.router.add_route()` #986
- Support SSL connections for gunicorn worker #1003
- Move obsolete examples to legacy folder
- Switch to multidict 2.0 and title-cased strings #1015
- `{FOO}e` logger format is case-sensitive now
- Fix logger report for unix socket 8e8469b
- Rename aiohttp.websocket to aiohttp._ws_impl
- Rename aiohttp.MsgType to aiohttp.WSMsgType
- Introduce aiohttp.WSMessage officially
- Rename Message -> WSMessage
- Remove deprecated decode param from resp.read(decode=True)
- Use 5min default client timeout #1028
- Relax HTTP method validation in UrlDispatcher #1037
- Pin minimal supported asyncio version to 3.4.2+ (`loop.is_closed()`
should be present)
- Remove aiohttp.websocket module (BACKWARD INCOMPATIBLE)
Please use high-level client and server approaches
- Link header for 451 status code is mandatory
- Fix test_client fixture to allow multiple clients per test #1072
- make_mocked_request now accepts dict as headers #1073
- Add Python 3.5.2/3.6+ compatibility patch for async generator
protocol change #1082
- Improvement: test_client can accept an instance object #1083
- Simplify ServerHttpProtocol implementation #1060
- Add a flag for optional showing directory index for static file
handling #921
- Define `web.Application.on_startup()` signal handler #1103
- Drop ChunkedParser and LinesParser #1111
- Call `Application.startup` in GunicornWebWorker #1105
- Fix client handling hostnames with 63 bytes when a port is given in
the url #1044
- Implement proxy support for ClientSession.ws_connect #1025
- Return named tuple from WebSocketResponse.can_prepare #1016
- Fix access_log_format in `GunicornWebWorker` #1117
- Setup Content-Type to application/octet-stream by default #1124
- Deprecate debug parameter from app.make_handler(), use
`Application(debug=True)` instead #1121
- Remove fragment string in request path #846
- Use aiodns.DNSResolver.gethostbyname() if available #1136
- Fix static file sending on uvloop when sendfile is available #1093
- Make prettier URLs if query is an empty dict #1143
- Fix redirects for HEAD requests #1147
- Default value for `StreamReader.read_nowait` is -1 from now #1150
- `aiohttp.StreamReader` is not inherited from `asyncio.StreamReader` from now
(BACKWARD INCOMPATIBLE) #1150
- Streams documentation added #1150
- Add `multipart` coroutine method for web Request object #1067
- Publish ClientSession.loop property #1149
- Fix static file with spaces #1140
- Fix piling up asyncio loop by cookie expiration callbacks #1061
- Drop `Timeout` class for sake of `async_timeout` external library.
`aiohttp.Timeout` is an alias for `async_timeout.timeout`
- `use_dns_cache` parameter of `aiohttp.TCPConnector` is `True` by
default (BACKWARD INCOMPATIBLE) #1152
- `aiohttp.TCPConnector` uses asynchronous DNS resolver if available by
default (BACKWARD INCOMPATIBLE) #1152
- Conform to RFC3986 - do not include url fragments in client requests #1174
- Drop `ClientSession.cookies` (BACKWARD INCOMPATIBLE) #1173
- Refactor `AbstractCookieJar` public API (BACKWARD INCOMPATIBLE) #1173
- Fix clashing cookies that have the same name but belong to different
domains (BACKWARD INCOMPATIBLE) #1125
- Support binary Content-Transfer-Encoding #1169
Platform: UNKNOWN
Classifier: License :: OSI Approved :: Apache Software License
Classifier: Intended Audience :: Developers
Classifier: Programming Language :: Python
Classifier: Programming Language :: Python :: 3
Classifier: Programming Language :: Python :: 3.4
Classifier: Programming Language :: Python :: 3.5
Classifier: Topic :: Internet :: WWW/HTTP

@@ -0,0 +1,162 @@
CHANGES.rst
CONTRIBUTORS.txt
LICENSE.txt
MANIFEST.in
Makefile
README.rst
setup.cfg
setup.py
aiohttp/__init__.py
aiohttp/_websocket.c
aiohttp/_websocket.pyx
aiohttp/_ws_impl.py
aiohttp/abc.py
aiohttp/client.py
aiohttp/client_reqrep.py
aiohttp/client_ws.py
aiohttp/connector.py
aiohttp/cookiejar.py
aiohttp/errors.py
aiohttp/file_sender.py
aiohttp/hdrs.py
aiohttp/helpers.py
aiohttp/log.py
aiohttp/multipart.py
aiohttp/parsers.py
aiohttp/protocol.py
aiohttp/pytest_plugin.py
aiohttp/resolver.py
aiohttp/server.py
aiohttp/signals.py
aiohttp/streams.py
aiohttp/test_utils.py
aiohttp/web.py
aiohttp/web_exceptions.py
aiohttp/web_reqrep.py
aiohttp/web_urldispatcher.py
aiohttp/web_ws.py
aiohttp/worker.py
aiohttp/wsgi.py
aiohttp.egg-info/PKG-INFO
aiohttp.egg-info/SOURCES.txt
aiohttp.egg-info/dependency_links.txt
aiohttp.egg-info/requires.txt
aiohttp.egg-info/top_level.txt
docs/Makefile
docs/abc.rst
docs/aiohttp-icon.ico
docs/aiohttp-icon.svg
docs/api.rst
docs/changes.rst
docs/client.rst
docs/client_reference.rst
docs/conf.py
docs/contributing.rst
docs/faq.rst
docs/glossary.rst
docs/gunicorn.rst
docs/index.rst
docs/logging.rst
docs/make.bat
docs/multipart.rst
docs/new_router.rst
docs/server.rst
docs/spelling_wordlist.txt
docs/streams.rst
docs/testing.rst
docs/tutorial.rst
docs/web.rst
docs/web_reference.rst
docs/_static/aiohttp-icon-128x128.png
docs/_static/aiohttp-icon-32x32.png
docs/_static/aiohttp-icon-64x64.png
docs/_static/aiohttp-icon-96x96.png
examples/background_tasks.py
examples/basic_srv.py
examples/cli_app.py
examples/client_auth.py
examples/client_json.py
examples/client_ws.py
examples/curl.py
examples/fake_server.py
examples/server.crt
examples/server.csr
examples/server.key
examples/static_files.py
examples/web_classview1.py
examples/web_cookies.py
examples/web_rewrite_headers_middleware.py
examples/web_srv.py
examples/web_ws.py
examples/websocket.html
examples/legacy/crawl.py
examples/legacy/srv.py
examples/legacy/tcp_protocol_parser.py
tests/conftest.py
tests/data.unknown_mime_type
tests/hello.txt.gz
tests/sample.crt
tests/sample.crt.der
tests/sample.key
tests/software_development_in_picture.jpg
tests/test_classbasedview.py
tests/test_client_connection.py
tests/test_client_functional.py
tests/test_client_functional_oldstyle.py
tests/test_client_request.py
tests/test_client_response.py
tests/test_client_session.py
tests/test_client_ws.py
tests/test_client_ws_functional.py
tests/test_connector.py
tests/test_cookiejar.py
tests/test_errors.py
tests/test_flowcontrol_streams.py
tests/test_helpers.py
tests/test_http_parser.py
tests/test_multipart.py
tests/test_parser_buffer.py
tests/test_protocol.py
tests/test_proxy.py
tests/test_pytest_plugin.py
tests/test_resolver.py
tests/test_run_app.py
tests/test_server.py
tests/test_signals.py
tests/test_stream_parser.py
tests/test_stream_protocol.py
tests/test_stream_writer.py
tests/test_streams.py
tests/test_test_utils.py
tests/test_urldispatch.py
tests/test_web_application.py
tests/test_web_cli.py
tests/test_web_exceptions.py
tests/test_web_functional.py
tests/test_web_middleware.py
tests/test_web_request.py
tests/test_web_request_handler.py
tests/test_web_response.py
tests/test_web_sendfile.py
tests/test_web_sendfile_functional.py
tests/test_web_urldispatcher.py
tests/test_web_websocket.py
tests/test_web_websocket_functional.py
tests/test_web_websocket_functional_oldstyle.py
tests/test_websocket_handshake.py
tests/test_websocket_parser.py
tests/test_websocket_writer.py
tests/test_worker.py
tests/test_wsgi.py
tests/autobahn/client.py
tests/autobahn/fuzzingclient.json
tests/autobahn/fuzzingserver.json
tests/autobahn/server.py
tests/test_py35/test_cbv35.py
tests/test_py35/test_client.py
tests/test_py35/test_client_websocket_35.py
tests/test_py35/test_multipart_35.py
tests/test_py35/test_resp.py
tests/test_py35/test_streams_35.py
tests/test_py35/test_test_utils_35.py
tests/test_py35/test_web_websocket_35.py

@@ -0,0 +1,65 @@
..\aiohttp\__init__.py
..\aiohttp\__pycache__\__init__.cpython-36.pyc
..\aiohttp\__pycache__\_ws_impl.cpython-36.pyc
..\aiohttp\__pycache__\abc.cpython-36.pyc
..\aiohttp\__pycache__\client.cpython-36.pyc
..\aiohttp\__pycache__\client_reqrep.cpython-36.pyc
..\aiohttp\__pycache__\client_ws.cpython-36.pyc
..\aiohttp\__pycache__\connector.cpython-36.pyc
..\aiohttp\__pycache__\cookiejar.cpython-36.pyc
..\aiohttp\__pycache__\errors.cpython-36.pyc
..\aiohttp\__pycache__\file_sender.cpython-36.pyc
..\aiohttp\__pycache__\hdrs.cpython-36.pyc
..\aiohttp\__pycache__\helpers.cpython-36.pyc
..\aiohttp\__pycache__\log.cpython-36.pyc
..\aiohttp\__pycache__\multipart.cpython-36.pyc
..\aiohttp\__pycache__\parsers.cpython-36.pyc
..\aiohttp\__pycache__\protocol.cpython-36.pyc
..\aiohttp\__pycache__\pytest_plugin.cpython-36.pyc
..\aiohttp\__pycache__\resolver.cpython-36.pyc
..\aiohttp\__pycache__\server.cpython-36.pyc
..\aiohttp\__pycache__\signals.cpython-36.pyc
..\aiohttp\__pycache__\streams.cpython-36.pyc
..\aiohttp\__pycache__\test_utils.cpython-36.pyc
..\aiohttp\__pycache__\web.cpython-36.pyc
..\aiohttp\__pycache__\web_exceptions.cpython-36.pyc
..\aiohttp\__pycache__\web_reqrep.cpython-36.pyc
..\aiohttp\__pycache__\web_urldispatcher.cpython-36.pyc
..\aiohttp\__pycache__\web_ws.cpython-36.pyc
..\aiohttp\__pycache__\worker.cpython-36.pyc
..\aiohttp\__pycache__\wsgi.cpython-36.pyc
..\aiohttp\_websocket.c
..\aiohttp\_websocket.pyx
..\aiohttp\_ws_impl.py
..\aiohttp\abc.py
..\aiohttp\client.py
..\aiohttp\client_reqrep.py
..\aiohttp\client_ws.py
..\aiohttp\connector.py
..\aiohttp\cookiejar.py
..\aiohttp\errors.py
..\aiohttp\file_sender.py
..\aiohttp\hdrs.py
..\aiohttp\helpers.py
..\aiohttp\log.py
..\aiohttp\multipart.py
..\aiohttp\parsers.py
..\aiohttp\protocol.py
..\aiohttp\pytest_plugin.py
..\aiohttp\resolver.py
..\aiohttp\server.py
..\aiohttp\signals.py
..\aiohttp\streams.py
..\aiohttp\test_utils.py
..\aiohttp\web.py
..\aiohttp\web_exceptions.py
..\aiohttp\web_reqrep.py
..\aiohttp\web_urldispatcher.py
..\aiohttp\web_ws.py
..\aiohttp\worker.py
..\aiohttp\wsgi.py
PKG-INFO
SOURCES.txt
dependency_links.txt
requires.txt
top_level.txt

@@ -1,201 +0,0 @@
Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
1. Definitions.
"License" shall mean the terms and conditions for use, reproduction,
and distribution as defined by Sections 1 through 9 of this document.
"Licensor" shall mean the copyright owner or entity authorized by
the copyright owner that is granting the License.
"Legal Entity" shall mean the union of the acting entity and all
other entities that control, are controlled by, or are under common
control with that entity. For the purposes of this definition,
"control" means (i) the power, direct or indirect, to cause the
direction or management of such entity, whether by contract or
otherwise, or (ii) ownership of fifty percent (50%) or more of the
outstanding shares, or (iii) beneficial ownership of such entity.
"You" (or "Your") shall mean an individual or Legal Entity
exercising permissions granted by this License.
"Source" form shall mean the preferred form for making modifications,
including but not limited to software source code, documentation
source, and configuration files.
"Object" form shall mean any form resulting from mechanical
transformation or translation of a Source form, including but
not limited to compiled object code, generated documentation,
and conversions to other media types.
"Work" shall mean the work of authorship, whether in Source or
Object form, made available under the License, as indicated by a
copyright notice that is included in or attached to the work
(an example is provided in the Appendix below).
"Derivative Works" shall mean any work, whether in Source or Object
form, that is based on (or derived from) the Work and for which the
editorial revisions, annotations, elaborations, or other modifications
represent, as a whole, an original work of authorship. For the purposes
of this License, Derivative Works shall not include works that remain
separable from, or merely link (or bind by name) to the interfaces of,
the Work and Derivative Works thereof.
"Contribution" shall mean any work of authorship, including
the original version of the Work and any modifications or additions
to that Work or Derivative Works thereof, that is intentionally
submitted to Licensor for inclusion in the Work by the copyright owner
or by an individual or Legal Entity authorized to submit on behalf of
the copyright owner. For the purposes of this definition, "submitted"
means any form of electronic, verbal, or written communication sent
to the Licensor or its representatives, including but not limited to
communication on electronic mailing lists, source code control systems,
and issue tracking systems that are managed by, or on behalf of, the
Licensor for the purpose of discussing and improving the Work, but
excluding communication that is conspicuously marked or otherwise
designated in writing by the copyright owner as "Not a Contribution."
"Contributor" shall mean Licensor and any individual or Legal Entity
on behalf of whom a Contribution has been received by Licensor and
subsequently incorporated within the Work.
2. Grant of Copyright License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
copyright license to reproduce, prepare Derivative Works of,
publicly display, publicly perform, sublicense, and distribute the
Work and such Derivative Works in Source or Object form.
3. Grant of Patent License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
(except as stated in this section) patent license to make, have made,
use, offer to sell, sell, import, and otherwise transfer the Work,
where such license applies only to those patent claims licensable
by such Contributor that are necessarily infringed by their
Contribution(s) alone or by combination of their Contribution(s)
with the Work to which such Contribution(s) was submitted. If You
institute patent litigation against any entity (including a
cross-claim or counterclaim in a lawsuit) alleging that the Work
or a Contribution incorporated within the Work constitutes direct
or contributory patent infringement, then any patent licenses
granted to You under this License for that Work shall terminate
as of the date such litigation is filed.
4. Redistribution. You may reproduce and distribute copies of the
Work or Derivative Works thereof in any medium, with or without
modifications, and in Source or Object form, provided that You
meet the following conditions:
(a) You must give any other recipients of the Work or
Derivative Works a copy of this License; and
(b) You must cause any modified files to carry prominent notices
stating that You changed the files; and
(c) You must retain, in the Source form of any Derivative Works
that You distribute, all copyright, patent, trademark, and
attribution notices from the Source form of the Work,
excluding those notices that do not pertain to any part of
the Derivative Works; and
(d) If the Work includes a "NOTICE" text file as part of its
distribution, then any Derivative Works that You distribute must
include a readable copy of the attribution notices contained
within such NOTICE file, excluding those notices that do not
pertain to any part of the Derivative Works, in at least one
of the following places: within a NOTICE text file distributed
as part of the Derivative Works; within the Source form or
documentation, if provided along with the Derivative Works; or,
within a display generated by the Derivative Works, if and
wherever such third-party notices normally appear. The contents
of the NOTICE file are for informational purposes only and
do not modify the License. You may add Your own attribution
notices within Derivative Works that You distribute, alongside
or as an addendum to the NOTICE text from the Work, provided
that such additional attribution notices cannot be construed
as modifying the License.
You may add Your own copyright statement to Your modifications and
may provide additional or different license terms and conditions
for use, reproduction, or distribution of Your modifications, or
for any such Derivative Works as a whole, provided Your use,
reproduction, and distribution of the Work otherwise complies with
the conditions stated in this License.
5. Submission of Contributions. Unless You explicitly state otherwise,
any Contribution intentionally submitted for inclusion in the Work
by You to the Licensor shall be under the terms and conditions of
this License, without any additional terms or conditions.
Notwithstanding the above, nothing herein shall supersede or modify
the terms of any separate license agreement you may have executed
with Licensor regarding such Contributions.
6. Trademarks. This License does not grant permission to use the trade
names, trademarks, service marks, or product names of the Licensor,
except as required for reasonable and customary use in describing the
origin of the Work and reproducing the content of the NOTICE file.
7. Disclaimer of Warranty. Unless required by applicable law or
agreed to in writing, Licensor provides the Work (and each
Contributor provides its Contributions) on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
implied, including, without limitation, any warranties or conditions
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
PARTICULAR PURPOSE. You are solely responsible for determining the
appropriateness of using or redistributing the Work and assume any
risks associated with Your exercise of permissions under this License.
8. Limitation of Liability. In no event and under no legal theory,
whether in tort (including negligence), contract, or otherwise,
unless required by applicable law (such as deliberate and grossly
negligent acts) or agreed to in writing, shall any Contributor be
liable to You for damages, including any direct, indirect, special,
incidental, or consequential damages of any character arising as a
result of this License or out of the use or inability to use the
Work (including but not limited to damages for loss of goodwill,
work stoppage, computer failure or malfunction, or any and all
other commercial damages or losses), even if such Contributor
has been advised of the possibility of such damages.
9. Accepting Warranty or Additional Liability. While redistributing
the Work or Derivative Works thereof, You may choose to offer,
and charge a fee for, acceptance of support, warranty, indemnity,
or other liability obligations and/or rights consistent with this
License. However, in accepting such obligations, You may act only
on Your own behalf and on Your sole responsibility, not on behalf
of any other Contributor, and only if You agree to indemnify,
defend, and hold each Contributor harmless for any liability
incurred by, or claims asserted against, such Contributor by reason
of your accepting any such warranty or additional liability.
END OF TERMS AND CONDITIONS
APPENDIX: How to apply the Apache License to your work.
To apply the Apache License to your work, attach the following
boilerplate notice, with the fields enclosed by brackets "{}"
replaced with your own identifying information. (Don't include
the brackets!) The text should be enclosed in the appropriate
comment syntax for the file format. We also recommend that a
file or class name and description of purpose be included on the
same "printed page" as the copyright notice for easier
identification within third-party archives.
Copyright 2013-2019 Nikolay Kim and Andrew Svetlov
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.

@@ -1,652 +0,0 @@
Metadata-Version: 2.1
Name: aiohttp
Version: 3.6.2
Summary: Async http client/server framework (asyncio)
Home-page: https://github.com/aio-libs/aiohttp
Author: Nikolay Kim
Author-email: fafhrd91@gmail.com
Maintainer: Nikolay Kim <fafhrd91@gmail.com>, Andrew Svetlov <andrew.svetlov@gmail.com>
Maintainer-email: aio-libs@googlegroups.com
License: Apache 2
Project-URL: Chat: Gitter, https://gitter.im/aio-libs/Lobby
Project-URL: CI: AppVeyor, https://ci.appveyor.com/project/aio-libs/aiohttp
Project-URL: CI: Circle, https://circleci.com/gh/aio-libs/aiohttp
Project-URL: CI: Shippable, https://app.shippable.com/github/aio-libs/aiohttp
Project-URL: CI: Travis, https://travis-ci.com/aio-libs/aiohttp
Project-URL: Coverage: codecov, https://codecov.io/github/aio-libs/aiohttp
Project-URL: Docs: RTD, https://docs.aiohttp.org
Project-URL: GitHub: issues, https://github.com/aio-libs/aiohttp/issues
Project-URL: GitHub: repo, https://github.com/aio-libs/aiohttp
Platform: UNKNOWN
Classifier: License :: OSI Approved :: Apache Software License
Classifier: Intended Audience :: Developers
Classifier: Programming Language :: Python
Classifier: Programming Language :: Python :: 3
Classifier: Programming Language :: Python :: 3.5
Classifier: Programming Language :: Python :: 3.6
Classifier: Programming Language :: Python :: 3.7
Classifier: Development Status :: 5 - Production/Stable
Classifier: Operating System :: POSIX
Classifier: Operating System :: MacOS :: MacOS X
Classifier: Operating System :: Microsoft :: Windows
Classifier: Topic :: Internet :: WWW/HTTP
Classifier: Framework :: AsyncIO
Requires-Python: >=3.5.3
Requires-Dist: attrs (>=17.3.0)
Requires-Dist: chardet (<4.0,>=2.0)
Requires-Dist: multidict (<5.0,>=4.5)
Requires-Dist: async-timeout (<4.0,>=3.0)
Requires-Dist: yarl (<2.0,>=1.0)
Requires-Dist: idna-ssl (>=1.0) ; python_version < "3.7"
Requires-Dist: typing-extensions (>=3.6.5) ; python_version < "3.7"
Provides-Extra: speedups
Requires-Dist: aiodns ; extra == 'speedups'
Requires-Dist: brotlipy ; extra == 'speedups'
Requires-Dist: cchardet ; extra == 'speedups'
==================================
Async http client/server framework
==================================
.. image:: https://raw.githubusercontent.com/aio-libs/aiohttp/master/docs/_static/aiohttp-icon-128x128.png
   :height: 64px
   :width: 64px
   :alt: aiohttp logo

|

.. image:: https://travis-ci.com/aio-libs/aiohttp.svg?branch=master
   :target: https://travis-ci.com/aio-libs/aiohttp
   :align: right
   :alt: Travis status for master branch

.. image:: https://ci.appveyor.com/api/projects/status/tnddy9k6pphl8w7k/branch/master?svg=true
   :target: https://ci.appveyor.com/project/aio-libs/aiohttp
   :align: right
   :alt: AppVeyor status for master branch

.. image:: https://codecov.io/gh/aio-libs/aiohttp/branch/master/graph/badge.svg
   :target: https://codecov.io/gh/aio-libs/aiohttp
   :alt: codecov.io status for master branch

.. image:: https://badge.fury.io/py/aiohttp.svg
   :target: https://pypi.org/project/aiohttp
   :alt: Latest PyPI package version

.. image:: https://readthedocs.org/projects/aiohttp/badge/?version=latest
   :target: https://docs.aiohttp.org/
   :alt: Latest Read The Docs

.. image:: https://badges.gitter.im/Join%20Chat.svg
   :target: https://gitter.im/aio-libs/Lobby
   :alt: Chat on Gitter
Key Features
============
- Supports both client and server side of HTTP protocol.
- Supports both client and server Web-Sockets out-of-the-box and avoids
Callback Hell.
- Provides Web-server with middlewares and pluggable routing.
Getting started
===============
Client
------
To get something from the web:
.. code-block:: python

    import aiohttp
    import asyncio

    async def fetch(session, url):
        async with session.get(url) as response:
            return await response.text()

    async def main():
        async with aiohttp.ClientSession() as session:
            html = await fetch(session, 'http://python.org')
            print(html)

    if __name__ == '__main__':
        loop = asyncio.get_event_loop()
        loop.run_until_complete(main())
Server
------
An example using a simple server:
.. code-block:: python

    # examples/server_simple.py
    from aiohttp import web

    async def handle(request):
        name = request.match_info.get('name', "Anonymous")
        text = "Hello, " + name
        return web.Response(text=text)

    async def wshandle(request):
        ws = web.WebSocketResponse()
        await ws.prepare(request)

        async for msg in ws:
            if msg.type == web.WSMsgType.text:
                await ws.send_str("Hello, {}".format(msg.data))
            elif msg.type == web.WSMsgType.binary:
                await ws.send_bytes(msg.data)
            elif msg.type == web.WSMsgType.close:
                break

        return ws

    app = web.Application()
    app.add_routes([web.get('/', handle),
                    web.get('/echo', wshandle),
                    web.get('/{name}', handle)])

    if __name__ == '__main__':
        web.run_app(app)
Documentation
=============
https://aiohttp.readthedocs.io/
Demos
=====
https://github.com/aio-libs/aiohttp-demos
External links
==============
* `Third party libraries
<http://aiohttp.readthedocs.io/en/latest/third_party.html>`_
* `Built with aiohttp
<http://aiohttp.readthedocs.io/en/latest/built_with.html>`_
* `Powered by aiohttp
<http://aiohttp.readthedocs.io/en/latest/powered_by.html>`_
Feel free to make a Pull Request for adding your link to these pages!
Communication channels
======================
*aio-libs* google group: https://groups.google.com/forum/#!forum/aio-libs
Feel free to post your questions and ideas here.
*gitter chat* https://gitter.im/aio-libs/Lobby
We support `Stack Overflow
<https://stackoverflow.com/questions/tagged/aiohttp>`_.
Please add *aiohttp* tag to your question there.
Requirements
============
- Python >= 3.5.3
- async-timeout_
- attrs_
- chardet_
- multidict_
- yarl_
Optionally you may install the cChardet_ and aiodns_ libraries (highly
recommended for the sake of speed).
.. _chardet: https://pypi.python.org/pypi/chardet
.. _aiodns: https://pypi.python.org/pypi/aiodns
.. _attrs: https://github.com/python-attrs/attrs
.. _multidict: https://pypi.python.org/pypi/multidict
.. _yarl: https://pypi.python.org/pypi/yarl
.. _async-timeout: https://pypi.python.org/pypi/async_timeout
.. _cChardet: https://pypi.python.org/pypi/cchardet
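For 3.6.x these extras are also packaged as ``aiohttp[speedups]`` (see the ``Provides-Extra: speedups`` metadata above). A minimal sketch, assuming that extra is installed, of wiring the aiodns-backed resolver into a session; ``AsyncResolver`` fails at construction if ``aiodns`` is missing:

.. code-block:: python

    import asyncio
    import aiohttp

    async def main():
        # AsyncResolver is backed by aiodns; TCPConnector accepts it directly.
        resolver = aiohttp.AsyncResolver()
        connector = aiohttp.TCPConnector(resolver=resolver)
        async with aiohttp.ClientSession(connector=connector) as session:
            async with session.get('http://python.org') as resp:
                print(resp.status)

    asyncio.get_event_loop().run_until_complete(main())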
License
=======
``aiohttp`` is offered under the Apache 2 license.
Keepsafe
========
The aiohttp community would like to thank Keepsafe
(https://www.getkeepsafe.com) for its support in the early days of
the project.
Source code
===========
The latest developer version is available in a GitHub repository:
https://github.com/aio-libs/aiohttp
Benchmarks
==========
If you are interested in efficiency, the AsyncIO community maintains a
list of benchmarks on the official wiki:
https://github.com/python/asyncio/wiki/Benchmarks
=========
Changelog
=========
..
    You should *NOT* be adding new change log entries to this file, this
    file is managed by towncrier. You *may* edit previous change logs to
    fix problems like typo corrections or such.

    To add a new change log entry, please see
    https://pip.pypa.io/en/latest/development/#adding-a-news-entry
    we named the news folder "changes".

    WARNING: Don't drop the next directive!
.. towncrier release notes start
3.6.2 (2019-10-09)
==================
Features
--------
- Made exceptions pickleable. Also changed the repr of some exceptions.
`#4077 <https://github.com/aio-libs/aiohttp/issues/4077>`_
- Use ``Iterable`` type hint instead of ``Sequence`` for ``Application`` *middleware*
parameter. `#4125 <https://github.com/aio-libs/aiohttp/issues/4125>`_
Bugfixes
--------
- Reset the ``sock_read`` timeout each time data is received for a
``aiohttp.ClientResponse``. `#3808
<https://github.com/aio-libs/aiohttp/issues/3808>`_
- Fix handling of expired cookies so they are not stored in CookieJar.
`#4063 <https://github.com/aio-libs/aiohttp/issues/4063>`_
- Fix misleading message in the string representation of ``ClientConnectorError``;
``self.ssl == None`` means default SSL context, not SSL disabled `#4097
<https://github.com/aio-libs/aiohttp/issues/4097>`_
- Don't clobber HTTP status when using FileResponse.
`#4106 <https://github.com/aio-libs/aiohttp/issues/4106>`_
Improved Documentation
----------------------
- Added minimal required logging configuration to logging documentation.
`#2469 <https://github.com/aio-libs/aiohttp/issues/2469>`_
- Update docs to reflect proxy support.
`#4100 <https://github.com/aio-libs/aiohttp/issues/4100>`_
- Fix typo in code example in testing docs.
`#4108 <https://github.com/aio-libs/aiohttp/issues/4108>`_
Misc
----
- `#4102 <https://github.com/aio-libs/aiohttp/issues/4102>`_
----
3.6.1 (2019-09-19)
==================
Features
--------
- Compatibility with Python 3.8.
`#4056 <https://github.com/aio-libs/aiohttp/issues/4056>`_
Bugfixes
--------
- Correct some exception string formats
`#4068 <https://github.com/aio-libs/aiohttp/issues/4068>`_
- Emit a warning when ``ssl.OP_NO_COMPRESSION`` is
unavailable because the runtime is built against
an outdated OpenSSL.
`#4052 <https://github.com/aio-libs/aiohttp/issues/4052>`_
- Update multidict requirement to >= 4.5
`#4057 <https://github.com/aio-libs/aiohttp/issues/4057>`_
Improved Documentation
----------------------
- Provide pytest-aiohttp namespace for pytest fixtures in docs.
`#3723 <https://github.com/aio-libs/aiohttp/issues/3723>`_
----
3.6.0 (2019-09-06)
==================
Features
--------
- Add support for Named Pipes (Site and Connector) under Windows. This feature requires
Proactor event loop to work. `#3629
<https://github.com/aio-libs/aiohttp/issues/3629>`_
- Removed ``Transfer-Encoding: chunked`` header from websocket responses to be
compatible with more http proxy servers. `#3798
<https://github.com/aio-libs/aiohttp/issues/3798>`_
- Accept non-GET request for starting websocket handshake on server side.
`#3980 <https://github.com/aio-libs/aiohttp/issues/3980>`_
Bugfixes
--------
- Raise a ClientResponseError instead of an AssertionError for a blank
HTTP Reason Phrase.
`#3532 <https://github.com/aio-libs/aiohttp/issues/3532>`_
- Fix an issue where cookies would sometimes not be set during a redirect.
`#3576 <https://github.com/aio-libs/aiohttp/issues/3576>`_
- Change normalize_path_middleware to use 308 redirect instead of 301.
This behavior should prevent clients from being unable to use PUT/POST
methods on endpoints that are redirected because of a trailing slash.
`#3579 <https://github.com/aio-libs/aiohttp/issues/3579>`_
- Drop the processed task from ``all_tasks()`` list early. It prevents logging about a
task with unhandled exception when the server is used in conjunction with
``asyncio.run()``. `#3587 <https://github.com/aio-libs/aiohttp/issues/3587>`_
- ``Signal`` type annotation changed from ``Signal[Callable[['TraceConfig'],
Awaitable[None]]]`` to ``Signal[Callable[ClientSession, SimpleNamespace, ...]``.
`#3595 <https://github.com/aio-libs/aiohttp/issues/3595>`_
- Use sanitized URL as Location header in redirects
`#3614 <https://github.com/aio-libs/aiohttp/issues/3614>`_
- Improve typing annotations for multipart.py along with changes required
by mypy in files that references multipart.py.
`#3621 <https://github.com/aio-libs/aiohttp/issues/3621>`_
- Close session created inside ``aiohttp.request`` when unhandled exception occurs
`#3628 <https://github.com/aio-libs/aiohttp/issues/3628>`_
- Cleanup per-chunk data in generic data read. Memory leak fixed.
`#3631 <https://github.com/aio-libs/aiohttp/issues/3631>`_
- Use correct type for add_view and family
`#3633 <https://github.com/aio-libs/aiohttp/issues/3633>`_
- Fix _keepalive field in __slots__ of ``RequestHandler``.
`#3644 <https://github.com/aio-libs/aiohttp/issues/3644>`_
- Properly handle ConnectionResetError, to silence the "Cannot write to closing
transport" exception when clients disconnect uncleanly.
`#3648 <https://github.com/aio-libs/aiohttp/issues/3648>`_
- Suppress pytest warnings due to ``test_utils`` classes
`#3660 <https://github.com/aio-libs/aiohttp/issues/3660>`_
- Fix overshadowing of overlapped sub-application prefixes.
`#3701 <https://github.com/aio-libs/aiohttp/issues/3701>`_
- Fixed return type annotation for WSMessage.json()
`#3720 <https://github.com/aio-libs/aiohttp/issues/3720>`_
- Properly expose TooManyRedirects publicly as documented.
`#3818 <https://github.com/aio-libs/aiohttp/issues/3818>`_
- Fix missing brackets for IPv6 in proxy CONNECT request
`#3841 <https://github.com/aio-libs/aiohttp/issues/3841>`_
- Make the signature of ``aiohttp.test_utils.TestClient.request`` match
``aiohttp.ClientSession.request`` according to the docs `#3852
<https://github.com/aio-libs/aiohttp/issues/3852>`_
- Use correct style for re-exported imports, makes mypy ``--strict`` mode happy.
`#3868 <https://github.com/aio-libs/aiohttp/issues/3868>`_
- Fixed type annotation for add_view method of UrlDispatcher to accept any subclass of
View `#3880 <https://github.com/aio-libs/aiohttp/issues/3880>`_
- Made cython HTTP parser set Reason-Phrase of the response to an empty string if it is
missing. `#3906 <https://github.com/aio-libs/aiohttp/issues/3906>`_
- Add URL to the string representation of ClientResponseError.
`#3959 <https://github.com/aio-libs/aiohttp/issues/3959>`_
- Accept ``istr`` keys in ``LooseHeaders`` type hints.
`#3976 <https://github.com/aio-libs/aiohttp/issues/3976>`_
- Fixed race conditions in _resolve_host caching and throttling when tracing is enabled.
`#4013 <https://github.com/aio-libs/aiohttp/issues/4013>`_
- For URLs like "unix://localhost/..." set Host HTTP header to "localhost" instead of
"localhost:None". `#4039 <https://github.com/aio-libs/aiohttp/issues/4039>`_
Improved Documentation
----------------------
- Modify documentation for Background Tasks to remove deprecated usage of event loop.
`#3526 <https://github.com/aio-libs/aiohttp/issues/3526>`_
- use ``if __name__ == '__main__':`` in server examples.
`#3775 <https://github.com/aio-libs/aiohttp/issues/3775>`_
- Update documentation reference to the default access logger.
`#3783 <https://github.com/aio-libs/aiohttp/issues/3783>`_
- Improve documentation for ``web.BaseRequest.path`` and ``web.BaseRequest.raw_path``.
`#3791 <https://github.com/aio-libs/aiohttp/issues/3791>`_
- Removed deprecation warning in tracing example docs
`#3964 <https://github.com/aio-libs/aiohttp/issues/3964>`_
----
3.5.4 (2019-01-12)
==================
Bugfixes
--------
- Fix stream ``.read()`` / ``.readany()`` / ``.iter_any()`` which used to return a
partial content only in case of compressed content
`#3525 <https://github.com/aio-libs/aiohttp/issues/3525>`_
3.5.3 (2019-01-10)
==================
Bugfixes
--------
- Fix type stubs for ``aiohttp.web.run_app(access_log=True)`` and fix edge case of
``access_log=True`` and the event loop being in debug mode. `#3504
<https://github.com/aio-libs/aiohttp/issues/3504>`_
- Fix ``aiohttp.ClientTimeout`` type annotations to accept ``None`` for fields
`#3511 <https://github.com/aio-libs/aiohttp/issues/3511>`_
- Send custom per-request cookies even if session jar is empty
`#3515 <https://github.com/aio-libs/aiohttp/issues/3515>`_
- Restore Linux binary wheels publishing on PyPI
----
3.5.2 (2019-01-08)
==================
Features
--------
- ``FileResponse`` from ``web_fileresponse.py`` uses a ``ThreadPoolExecutor`` to work
with files asynchronously. I/O based payloads from ``payload.py`` uses a
``ThreadPoolExecutor`` to work with I/O objects asynchronously. `#3313
<https://github.com/aio-libs/aiohttp/issues/3313>`_
- Render Internal Server Errors in plain text if the browser does not support HTML.
`#3483 <https://github.com/aio-libs/aiohttp/issues/3483>`_
Bugfixes
--------
- Preserve MultipartWriter parts headers on write. Refactor the way how
``Payload.headers`` are handled. Payload instances now always have headers and
Content-Type defined. Fix Payload Content-Disposition header reset after initial
creation. `#3035 <https://github.com/aio-libs/aiohttp/issues/3035>`_
- Log suppressed exceptions in ``GunicornWebWorker``.
`#3464 <https://github.com/aio-libs/aiohttp/issues/3464>`_
- Remove wildcard imports.
`#3468 <https://github.com/aio-libs/aiohttp/issues/3468>`_
- Use the same task for app initialization and web server handling in gunicorn workers.
It allows using Python 3.7 context vars smoothly.
`#3471 <https://github.com/aio-libs/aiohttp/issues/3471>`_
- Fix handling of chunked+gzipped response when first chunk does not give uncompressed
data `#3477 <https://github.com/aio-libs/aiohttp/issues/3477>`_
- Replace ``collections.MutableMapping`` with ``collections.abc.MutableMapping`` to
avoid a deprecation warning. `#3480
<https://github.com/aio-libs/aiohttp/issues/3480>`_
- ``Payload.size`` type annotation changed from ``Optional[float]`` to
``Optional[int]``. `#3484 <https://github.com/aio-libs/aiohttp/issues/3484>`_
- Ignore done tasks when cancelling pending activities on ``web.run_app`` finalization.
`#3497 <https://github.com/aio-libs/aiohttp/issues/3497>`_
Improved Documentation
----------------------
- Add documentation for ``aiohttp.web.HTTPException``.
`#3490 <https://github.com/aio-libs/aiohttp/issues/3490>`_
Misc
----
- `#3487 <https://github.com/aio-libs/aiohttp/issues/3487>`_
----
3.5.1 (2018-12-24)
====================
- Fix a regression about ``ClientSession._requote_redirect_url`` modification in debug
mode.
3.5.0 (2018-12-22)
====================
Features
--------
- The library type annotations are checked in strict mode now.
- Add support for setting cookies for individual request (`#2387
<https://github.com/aio-libs/aiohttp/pull/2387>`_)
- Application.add_domain implementation (`#2809
<https://github.com/aio-libs/aiohttp/pull/2809>`_)
- The default ``app`` in the request returned by ``test_utils.make_mocked_request`` can
now have objects assigned to it and retrieved using the ``[]`` operator. (`#3174
<https://github.com/aio-libs/aiohttp/pull/3174>`_)
- Make ``request.url`` accessible when transport is closed. (`#3177
<https://github.com/aio-libs/aiohttp/pull/3177>`_)
- Add ``zlib_executor_size`` argument to ``Response`` constructor to allow compression
to run in a background executor to avoid blocking the main thread and potentially
triggering health check failures. (`#3205
<https://github.com/aio-libs/aiohttp/pull/3205>`_)
- Enable users to set ``ClientTimeout`` in ``aiohttp.request`` (`#3213
<https://github.com/aio-libs/aiohttp/pull/3213>`_)
- Don't raise a warning if ``NETRC`` environment variable is not set and ``~/.netrc``
file doesn't exist. (`#3267 <https://github.com/aio-libs/aiohttp/pull/3267>`_)
- Add a default logging handler to ``web.run_app``. If the ``Application.debug`` flag is set
and the default logger ``aiohttp.access`` is used, access logs will now be output
using a *stderr* ``StreamHandler`` if no handlers are attached. Furthermore, if the
default logger has no log level set, the log level will be set to ``DEBUG``. (`#3324
<https://github.com/aio-libs/aiohttp/pull/3324>`_)
- Add method argument to ``session.ws_connect()``. Sometimes server API requires a
different HTTP method for WebSocket connection establishment. For example, ``Docker
exec`` needs POST. (`#3378 <https://github.com/aio-libs/aiohttp/pull/3378>`_)
- Create a task per request handling. (`#3406
<https://github.com/aio-libs/aiohttp/pull/3406>`_)
Bugfixes
--------
- Enable passing ``access_log_class`` via ``handler_args`` (`#3158
<https://github.com/aio-libs/aiohttp/pull/3158>`_)
- Return empty bytes with end-of-chunk marker in empty stream reader. (`#3186
<https://github.com/aio-libs/aiohttp/pull/3186>`_)
- Accept ``CIMultiDictProxy`` instances for ``headers`` argument in ``web.Response``
constructor. (`#3207 <https://github.com/aio-libs/aiohttp/pull/3207>`_)
- Don't uppercase HTTP method in parser (`#3233
<https://github.com/aio-libs/aiohttp/pull/3233>`_)
- Make method match regexp RFC-7230 compliant (`#3235
<https://github.com/aio-libs/aiohttp/pull/3235>`_)
- Add ``app.pre_frozen`` state to properly handle startup signals in
sub-applications. (`#3237 <https://github.com/aio-libs/aiohttp/pull/3237>`_)
- Enhanced parsing and validation of helpers.BasicAuth.decode. (`#3239
<https://github.com/aio-libs/aiohttp/pull/3239>`_)
- Change imports from collections module in preparation for 3.8. (`#3258
<https://github.com/aio-libs/aiohttp/pull/3258>`_)
- Ensure Host header is added first to ClientRequest to better replicate browser (`#3265
<https://github.com/aio-libs/aiohttp/pull/3265>`_)
- Fix forward compatibility with Python 3.8: importing ABCs directly from the
collections module will not be supported anymore. (`#3273
<https://github.com/aio-libs/aiohttp/pull/3273>`_)
- Keep the query string by ``normalize_path_middleware``. (`#3278
<https://github.com/aio-libs/aiohttp/pull/3278>`_)
- Fix missing parameter ``raise_for_status`` for aiohttp.request() (`#3290
<https://github.com/aio-libs/aiohttp/pull/3290>`_)
- Bracket IPv6 addresses in the HOST header (`#3304
<https://github.com/aio-libs/aiohttp/pull/3304>`_)
- Fix default message for server ping and pong frames. (`#3308
<https://github.com/aio-libs/aiohttp/pull/3308>`_)
- Fix tests/test_connector.py typo and tests/autobahn/server.py duplicate loop
def. (`#3337 <https://github.com/aio-libs/aiohttp/pull/3337>`_)
- Fix false-negative indicator end_of_HTTP_chunk in StreamReader.readchunk function
(`#3361 <https://github.com/aio-libs/aiohttp/pull/3361>`_)
- Release HTTP response before raising status exception (`#3364
<https://github.com/aio-libs/aiohttp/pull/3364>`_)
- Fix task cancellation when ``sendfile()`` syscall is used by static file
handling. (`#3383 <https://github.com/aio-libs/aiohttp/pull/3383>`_)
- Fix stack trace for ``asyncio.TimeoutError`` which was not logged, when it is caught
in the handler. (`#3414 <https://github.com/aio-libs/aiohttp/pull/3414>`_)
Improved Documentation
----------------------
- Improve documentation of ``Application.make_handler`` parameters. (`#3152
<https://github.com/aio-libs/aiohttp/pull/3152>`_)
- Fix BaseRequest.raw_headers doc. (`#3215
<https://github.com/aio-libs/aiohttp/pull/3215>`_)
- Fix typo in TypeError exception reason in ``web.Application._handle`` (`#3229
<https://github.com/aio-libs/aiohttp/pull/3229>`_)
- Make server access log format placeholder %b documentation reflect
behavior and docstring. (`#3307 <https://github.com/aio-libs/aiohttp/pull/3307>`_)
Deprecations and Removals
-------------------------
- Deprecate modification of ``session.requote_redirect_url`` (`#2278
<https://github.com/aio-libs/aiohttp/pull/2278>`_)
- Deprecate ``stream.unread_data()`` (`#3260
<https://github.com/aio-libs/aiohttp/pull/3260>`_)
- Deprecated use of boolean in ``resp.enable_compression()`` (`#3318
<https://github.com/aio-libs/aiohttp/pull/3318>`_)
- Encourage creation of aiohttp public objects inside a coroutine (`#3331
<https://github.com/aio-libs/aiohttp/pull/3331>`_)
- Drop dead ``Connection.detach()`` and ``Connection.writer``. Both methods were broken
for more than 2 years. (`#3358 <https://github.com/aio-libs/aiohttp/pull/3358>`_)
- Deprecate ``app.loop``, ``request.loop``, ``client.loop`` and ``connector.loop``
properties. (`#3374 <https://github.com/aio-libs/aiohttp/pull/3374>`_)
- Deprecate explicit debug argument. Use asyncio debug mode instead. (`#3381
<https://github.com/aio-libs/aiohttp/pull/3381>`_)
- Deprecate body parameter in HTTPException (and derived classes) constructor. (`#3385
<https://github.com/aio-libs/aiohttp/pull/3385>`_)
- Deprecate bare connector close, use ``async with connector:`` and ``await
connector.close()`` instead. (`#3417
<https://github.com/aio-libs/aiohttp/pull/3417>`_)
- Deprecate obsolete ``read_timeout`` and ``conn_timeout`` in ``ClientSession``
constructor. (`#3438 <https://github.com/aio-libs/aiohttp/pull/3438>`_)
Misc
----
- #3341, #3351

@@ -1,124 +0,0 @@
aiohttp-3.6.2.dist-info/INSTALLER,sha256=zuuue4knoyJ-UwPPXg8fezS7VCrXJQrAP7zeNuwvFQg,4
aiohttp-3.6.2.dist-info/LICENSE.txt,sha256=atcq6P9K6Td0Wq4oBfNDqYf6o6YGrHLGCfLUj3GZspQ,11533
aiohttp-3.6.2.dist-info/METADATA,sha256=4kebVhrza_aP2QNEcLfPESEhoVd7Jc1une-JuWJlVlE,24410
aiohttp-3.6.2.dist-info/RECORD,,
aiohttp-3.6.2.dist-info/WHEEL,sha256=uQaeujkjkt7SlmOZGXO6onhwBPrzw2WTI2otbCZzdNI,106
aiohttp-3.6.2.dist-info/top_level.txt,sha256=iv-JIaacmTl-hSho3QmphcKnbRRYx1st47yjz_178Ro,8
aiohttp/__init__.py,sha256=k5JorjbCoRvIyRSvcz-N_LFgNe1wX5HtjLCwNkC7zdY,8427
aiohttp/__pycache__/__init__.cpython-36.pyc,,
aiohttp/__pycache__/abc.cpython-36.pyc,,
aiohttp/__pycache__/base_protocol.cpython-36.pyc,,
aiohttp/__pycache__/client.cpython-36.pyc,,
aiohttp/__pycache__/client_exceptions.cpython-36.pyc,,
aiohttp/__pycache__/client_proto.cpython-36.pyc,,
aiohttp/__pycache__/client_reqrep.cpython-36.pyc,,
aiohttp/__pycache__/client_ws.cpython-36.pyc,,
aiohttp/__pycache__/connector.cpython-36.pyc,,
aiohttp/__pycache__/cookiejar.cpython-36.pyc,,
aiohttp/__pycache__/formdata.cpython-36.pyc,,
aiohttp/__pycache__/frozenlist.cpython-36.pyc,,
aiohttp/__pycache__/hdrs.cpython-36.pyc,,
aiohttp/__pycache__/helpers.cpython-36.pyc,,
aiohttp/__pycache__/http.cpython-36.pyc,,
aiohttp/__pycache__/http_exceptions.cpython-36.pyc,,
aiohttp/__pycache__/http_parser.cpython-36.pyc,,
aiohttp/__pycache__/http_websocket.cpython-36.pyc,,
aiohttp/__pycache__/http_writer.cpython-36.pyc,,
aiohttp/__pycache__/locks.cpython-36.pyc,,
aiohttp/__pycache__/log.cpython-36.pyc,,
aiohttp/__pycache__/multipart.cpython-36.pyc,,
aiohttp/__pycache__/payload.cpython-36.pyc,,
aiohttp/__pycache__/payload_streamer.cpython-36.pyc,,
aiohttp/__pycache__/pytest_plugin.cpython-36.pyc,,
aiohttp/__pycache__/resolver.cpython-36.pyc,,
aiohttp/__pycache__/signals.cpython-36.pyc,,
aiohttp/__pycache__/streams.cpython-36.pyc,,
aiohttp/__pycache__/tcp_helpers.cpython-36.pyc,,
aiohttp/__pycache__/test_utils.cpython-36.pyc,,
aiohttp/__pycache__/tracing.cpython-36.pyc,,
aiohttp/__pycache__/typedefs.cpython-36.pyc,,
aiohttp/__pycache__/web.cpython-36.pyc,,
aiohttp/__pycache__/web_app.cpython-36.pyc,,
aiohttp/__pycache__/web_exceptions.cpython-36.pyc,,
aiohttp/__pycache__/web_fileresponse.cpython-36.pyc,,
aiohttp/__pycache__/web_log.cpython-36.pyc,,
aiohttp/__pycache__/web_middlewares.cpython-36.pyc,,
aiohttp/__pycache__/web_protocol.cpython-36.pyc,,
aiohttp/__pycache__/web_request.cpython-36.pyc,,
aiohttp/__pycache__/web_response.cpython-36.pyc,,
aiohttp/__pycache__/web_routedef.cpython-36.pyc,,
aiohttp/__pycache__/web_runner.cpython-36.pyc,,
aiohttp/__pycache__/web_server.cpython-36.pyc,,
aiohttp/__pycache__/web_urldispatcher.cpython-36.pyc,,
aiohttp/__pycache__/web_ws.cpython-36.pyc,,
aiohttp/__pycache__/worker.cpython-36.pyc,,
aiohttp/_cparser.pxd,sha256=xvsLl13ZXXyHGyb2Us7WsLncndQrxhyGB4KXnvbsRtQ,4099
aiohttp/_find_header.c,sha256=MOZn07_ot-UcOdQBpYAWQmyigqLvMwkqa_7l4M7D1dI,199932
aiohttp/_find_header.h,sha256=HistyxY7K3xEJ53Y5xEfwrDVDkfcV0zQ9mkzMgzi_jo,184
aiohttp/_find_header.pxd,sha256=BFUSmxhemBtblqxzjzH3x03FfxaWlTyuAIOz8YZ5_nM,70
aiohttp/_frozenlist.c,sha256=-vfgzV6cNjUykuqt1kkWDiT2U92BR2zhL9b9yDiiodg,288943
aiohttp/_frozenlist.cp36-win_amd64.pyd,sha256=SN72FLXG8KJYhqgT9BtULfLFhjSmvv_C-oDeQPhlpH8,79872
aiohttp/_frozenlist.pyx,sha256=SB851KmtWpiJ2ZB05Tpo4855VkCyRtgMs843Wz8kFeg,2713
aiohttp/_headers.pxi,sha256=PxiakDsuEs0O94eHRlPcerO24TqPAxc0BtX86XZL4gw,2111
aiohttp/_helpers.c,sha256=sQcHpEGAX3jEvA8jujh4_D_fev9cRjMAc5CySqtHYrg,208657
aiohttp/_helpers.cp36-win_amd64.pyd,sha256=ezuDwotCokL_pvZWHfe9kppSqetibStK3Ob727IJaGY,59904
aiohttp/_helpers.pyi,sha256=C6Q4W8EwElvD1gF1siRGMVG7evEX8fWWstZHL1BbsDA,212
aiohttp/_helpers.pyx,sha256=tgl7fZh0QMT6cjf4jSJ8iaO6DdQD3GON2-SH4N5_ETg,1084
aiohttp/_http_parser.c,sha256=W1sETtDrrBdnBiSOpqaDcO9DcE9zhyLjPTq4WKIK0bc,997494
aiohttp/_http_parser.cp36-win_amd64.pyd,sha256=E54uSiDD1EJj7fCWuOxxqGJKzvCif6HV5ewK1US3ya8,255488
aiohttp/_http_parser.pyx,sha256=C2XxooYRput7XPQzbaGMDrtvJtmhWa58SDPytyuAwGk,29577
aiohttp/_http_writer.c,sha256=-wuBZwiaUXEy1Zj-R5BD5igH7cUg_CYb5ZvYMsh8vzo,211620
aiohttp/_http_writer.cp36-win_amd64.pyd,sha256=wsDiKyfAERR76tMESHKZ9xsEABBowsdYWKjvF7xv2fs,51712
aiohttp/_http_writer.pyx,sha256=TzCawCBLMe7w9eX2SEcUcLYySwkFfrfjaEYHS0Uvjtg,4353
aiohttp/_websocket.c,sha256=JrG6bXW3OR8sfxl5V1Q3VTXvGBbFTYgzgdbhQHr3LGI,136606
aiohttp/_websocket.cp36-win_amd64.pyd,sha256=JvOl8VKDwvfhr3TDGovNSUYK_8smCphWhewuKzk4l1Y,39424
aiohttp/_websocket.pyx,sha256=Ig8jXl_wkAXPugEWS0oPYo0-BnL8zT7uBG6BrYqVXdA,1613
aiohttp/abc.py,sha256=s3wtDI3os8uX4FdQbsvJwr67cFGhylif0mR5k2SKY04,5600
aiohttp/base_protocol.py,sha256=5PJImwc0iX8kR3VjZn1D_SAeL-6JKERi87iGHEYjJQ4,2744
aiohttp/client.py,sha256=DYv-h8V2wljt4hRmPDmU2czk9zSlSn8zua9MgssSEiY,45130
aiohttp/client_exceptions.py,sha256=RCbzCGw_HcaqnL4AHf3nol32xH_2xu1hrYbLNgpjHqk,8786
aiohttp/client_proto.py,sha256=XDXJ0G9RW8m80wHahzjgp4T5S3Rf6LSYks9Q9MajSQg,8276
aiohttp/client_reqrep.py,sha256=zf6GFaDYvpy50HZ4GntrT8flcc6B4HfwnlHw_yYdGMw,37064
aiohttp/client_ws.py,sha256=OUkkw9RwRHRmAakBibE6c63VLMWGVgoyRadoC22wtNY,10995
aiohttp/connector.py,sha256=pbq2XHrujiyQXbIhzXQK6E1zrzRYedzt8xlGNmvbQcM,43672
aiohttp/cookiejar.py,sha256=lNwvnGX3BjIDU4btE50AUsBQditLXzJhsPPUMZo-dkI,12249
aiohttp/formdata.py,sha256=1yNFnS6O0wUrIL4_V66-DwyjS3nWVd0JiPIjWKbTZTs,5957
aiohttp/frozenlist.py,sha256=PSElO5biFCVHyVEr6-hXy7--cDaHAxaWDrsFxESHsFc,1853
aiohttp/frozenlist.pyi,sha256=z-EGiL4Q5MTe1wxDZINsIhqh4Eb0oT9Xn0X_Rt7C9ns,1512
aiohttp/hdrs.py,sha256=PmN2SUiMmwiC0TMEEMSFfwirUpnrzy3jwUhniPGFlmc,3549
aiohttp/helpers.py,sha256=yAdG1c-axo7-Vsf3CRaEqb7hU5Ej-FpUgZowGA76f_U,23613
aiohttp/http.py,sha256=H9xNqvagxteFvx2R7AeYiGfze7uR6VKF5IsUAITr7d4,2183
aiohttp/http_exceptions.py,sha256=Oby70EpyDmwpsb4DpCFYXw-sa856HmWv8IjeHlWWlJo,2771
aiohttp/http_parser.py,sha256=Ttk5BSX11cXMaFJmquzd1oNkZbnodghQvBgdUGdQxnE,28676
aiohttp/http_websocket.py,sha256=KmHznrwSjtpUgxbFafBg1MaAaCpxGxoK0IL8wDKg9f8,25400
aiohttp/http_writer.py,sha256=VBMPy_AaB7m_keycuu05SCN2S3GVVyY8UCHG-W86Y1w,5411
aiohttp/locks.py,sha256=6DiJHW1eQKXypu1eWXZT3_amPhFBK-jnxdI-_BpYICk,1278
aiohttp/log.py,sha256=qAQMjI6XpX3MOAZATN4HcG0tIceSreR54orlYZaoJ0A,333
aiohttp/multipart.py,sha256=RPXfp5GMauxW19nbBaLAkzgUFKTQ9eMo4XtZ7ItGyo4,33740
aiohttp/payload.py,sha256=lCF_pZvwyBKJGk4OOLYEQhtxUwOW8rsFF0pxisvfBps,14483
aiohttp/payload_streamer.py,sha256=7koj4FVujDGriDIOes48XPp5BK9tsWYyTxJG-3aNaHc,2177
aiohttp/py.typed,sha256=E84IaZyFwfLqvXjOVW4LS6WH7QOaKEFpNh9TFyzHNQc,6
aiohttp/pytest_plugin.py,sha256=1_XNSrZS-czuaNVt4qvRQs-GbIIl8DaLykGpoDlZfhU,11187
aiohttp/resolver.py,sha256=mQvusmMHpS0JekvnX7R1y4aqQ7BIIv3FIkxO5wgv2xQ,3738
aiohttp/signals.py,sha256=I_QAX1S7VbN7KDnNO6CSnAzhzx42AYh2Dto_FC9DQ3k,982
aiohttp/signals.pyi,sha256=pg4KElFcxBNFU-OQpTe2x-7qKJ79bAlemgqE-yaciiU,341
aiohttp/streams.py,sha256=EPM7T5_aJLOXlBTIEeFapIQ1O33KsHTvT-wWH3X0QvQ,21093
aiohttp/tcp_helpers.py,sha256=q9fHztjKbR57sCc4zWoo89QDW88pLT0OpcdHLGcV3Fo,1694
aiohttp/test_utils.py,sha256=_GjrPdE_9v0SxzbM4Tmt8vst-KJPwL2ILM_Rl1jHhi4,21530
aiohttp/tracing.py,sha256=GGhlQDrx5AVwFt33Zl4DvBIoFcR7sXAsgXNxvkd2Uus,13740
aiohttp/typedefs.py,sha256=o4R9uAySHxTzedIfX3UPbD0a5TnD5inc_M-h_4qyC4U,1377
aiohttp/web.py,sha256=KQXp0C__KpeX8nYM3FWl-eoMAmj9LZIbx7YeI39pQco,19940
aiohttp/web_app.py,sha256=dHOhoDoakwdrya0cc6Jl6K723MKGmd_M5LxH3wDeGQI,17779
aiohttp/web_exceptions.py,sha256=CQvslnHcpFnreO-qNjnKOWQev7ZvlTG6jfV14NQwb1Q,10519
aiohttp/web_fileresponse.py,sha256=TftBNfbgowCQ0L5Iud-dewCAnXq5tIyP-8iZ-KrSHw8,13118
aiohttp/web_log.py,sha256=gOR8iLbhjeAUwGL-21qD31kA0HlYSNhpdX6eNwJ-3Uo,8490
aiohttp/web_middlewares.py,sha256=jATe_igeeoyBoWKBDW_ISOOzFKvxSoLJE1QPTqZPWGc,4310
aiohttp/web_protocol.py,sha256=Zol5oVApIE12NDLBV_W1oKW8AN-sGdBfC0RFMI050U0,22791
aiohttp/web_request.py,sha256=xzvj84uGe5Uuug1b4iKWZl8uko_0TpzYKa00POke_NM,26526
aiohttp/web_response.py,sha256=CEx04E7NLNg6mfgTjT0QPS9vJuglbw3UQvwob6Qeb7c,26202
aiohttp/web_routedef.py,sha256=5QCl85zQml2qoj7bkC9XMoK4stBVuUoiq_0uefxifjc,6293
aiohttp/web_runner.py,sha256=ArW4NjMJ24Fv68Ez-9hPL1WNzVygDYEWJ4aIfzOMKz8,11479
aiohttp/web_server.py,sha256=P826xDCDs4VgeksMam8OHKm_VzprXuOpsJrysqj3CVg,2222
aiohttp/web_urldispatcher.py,sha256=8uhNNXlHd2WJfJ4wcyQ1UxoRM1VUyWWwQhK-TPrM_GM,40043
aiohttp/web_ws.py,sha256=mAU6Ln3AbMZeXjUZSSA5MmE39hTajJIMxBE0xnq-4Tc,17414
aiohttp/worker.py,sha256=yatPZxpUOp9CzDA05Jb2UWi0eo2PgGWlQm4lIFCRCSY,8420

@ -1,5 +0,0 @@
Wheel-Version: 1.0
Generator: bdist_wheel (0.33.6)
Root-Is-Purelib: false
Tag: cp36-cp36m-win_amd64

@ -1,226 +1,41 @@
__version__ = '3.6.2'
__version__ = '1.0.5'
from typing import Tuple # noqa
# Deprecated, keep it here for a while for backward compatibility.
import multidict # noqa
from . import hdrs as hdrs
from .client import BaseConnector as BaseConnector
from .client import ClientConnectionError as ClientConnectionError
from .client import (
ClientConnectorCertificateError as ClientConnectorCertificateError,
)
from .client import ClientConnectorError as ClientConnectorError
from .client import ClientConnectorSSLError as ClientConnectorSSLError
from .client import ClientError as ClientError
from .client import ClientHttpProxyError as ClientHttpProxyError
from .client import ClientOSError as ClientOSError
from .client import ClientPayloadError as ClientPayloadError
from .client import ClientProxyConnectionError as ClientProxyConnectionError
from .client import ClientRequest as ClientRequest
from .client import ClientResponse as ClientResponse
from .client import ClientResponseError as ClientResponseError
from .client import ClientSession as ClientSession
from .client import ClientSSLError as ClientSSLError
from .client import ClientTimeout as ClientTimeout
from .client import ClientWebSocketResponse as ClientWebSocketResponse
from .client import ContentTypeError as ContentTypeError
from .client import Fingerprint as Fingerprint
from .client import InvalidURL as InvalidURL
from .client import NamedPipeConnector as NamedPipeConnector
from .client import RequestInfo as RequestInfo
from .client import ServerConnectionError as ServerConnectionError
from .client import ServerDisconnectedError as ServerDisconnectedError
from .client import ServerFingerprintMismatch as ServerFingerprintMismatch
from .client import ServerTimeoutError as ServerTimeoutError
from .client import TCPConnector as TCPConnector
from .client import TooManyRedirects as TooManyRedirects
from .client import UnixConnector as UnixConnector
from .client import WSServerHandshakeError as WSServerHandshakeError
from .client import request as request
from .cookiejar import CookieJar as CookieJar
from .cookiejar import DummyCookieJar as DummyCookieJar
from .formdata import FormData as FormData
from .helpers import BasicAuth as BasicAuth
from .helpers import ChainMapProxy as ChainMapProxy
from .http import HttpVersion as HttpVersion
from .http import HttpVersion10 as HttpVersion10
from .http import HttpVersion11 as HttpVersion11
from .http import WebSocketError as WebSocketError
from .http import WSCloseCode as WSCloseCode
from .http import WSMessage as WSMessage
from .http import WSMsgType as WSMsgType
from .multipart import (
BadContentDispositionHeader as BadContentDispositionHeader,
)
from .multipart import BadContentDispositionParam as BadContentDispositionParam
from .multipart import BodyPartReader as BodyPartReader
from .multipart import MultipartReader as MultipartReader
from .multipart import MultipartWriter as MultipartWriter
from .multipart import (
content_disposition_filename as content_disposition_filename,
)
from .multipart import parse_content_disposition as parse_content_disposition
from .payload import PAYLOAD_REGISTRY as PAYLOAD_REGISTRY
from .payload import AsyncIterablePayload as AsyncIterablePayload
from .payload import BufferedReaderPayload as BufferedReaderPayload
from .payload import BytesIOPayload as BytesIOPayload
from .payload import BytesPayload as BytesPayload
from .payload import IOBasePayload as IOBasePayload
from .payload import JsonPayload as JsonPayload
from .payload import Payload as Payload
from .payload import StringIOPayload as StringIOPayload
from .payload import StringPayload as StringPayload
from .payload import TextIOPayload as TextIOPayload
from .payload import get_payload as get_payload
from .payload import payload_type as payload_type
from .payload_streamer import streamer as streamer
from .resolver import AsyncResolver as AsyncResolver
from .resolver import DefaultResolver as DefaultResolver
from .resolver import ThreadedResolver as ThreadedResolver
from .signals import Signal as Signal
from .streams import EMPTY_PAYLOAD as EMPTY_PAYLOAD
from .streams import DataQueue as DataQueue
from .streams import EofStream as EofStream
from .streams import FlowControlDataQueue as FlowControlDataQueue
from .streams import StreamReader as StreamReader
from .tracing import TraceConfig as TraceConfig
from .tracing import (
TraceConnectionCreateEndParams as TraceConnectionCreateEndParams,
)
from .tracing import (
TraceConnectionCreateStartParams as TraceConnectionCreateStartParams,
)
from .tracing import (
TraceConnectionQueuedEndParams as TraceConnectionQueuedEndParams,
)
from .tracing import (
TraceConnectionQueuedStartParams as TraceConnectionQueuedStartParams,
)
from .tracing import (
TraceConnectionReuseconnParams as TraceConnectionReuseconnParams,
)
from .tracing import TraceDnsCacheHitParams as TraceDnsCacheHitParams
from .tracing import TraceDnsCacheMissParams as TraceDnsCacheMissParams
from .tracing import (
TraceDnsResolveHostEndParams as TraceDnsResolveHostEndParams,
)
from .tracing import (
TraceDnsResolveHostStartParams as TraceDnsResolveHostStartParams,
)
from .tracing import TraceRequestChunkSentParams as TraceRequestChunkSentParams
from .tracing import TraceRequestEndParams as TraceRequestEndParams
from .tracing import TraceRequestExceptionParams as TraceRequestExceptionParams
from .tracing import TraceRequestRedirectParams as TraceRequestRedirectParams
from .tracing import TraceRequestStartParams as TraceRequestStartParams
from .tracing import (
TraceResponseChunkReceivedParams as TraceResponseChunkReceivedParams,
)
# This relies on each of the submodules having an __all__ variable.
__all__ = (
'hdrs',
# client
'BaseConnector',
'ClientConnectionError',
'ClientConnectorCertificateError',
'ClientConnectorError',
'ClientConnectorSSLError',
'ClientError',
'ClientHttpProxyError',
'ClientOSError',
'ClientPayloadError',
'ClientProxyConnectionError',
'ClientResponse',
'ClientRequest',
'ClientResponseError',
'ClientSSLError',
'ClientSession',
'ClientTimeout',
'ClientWebSocketResponse',
'ContentTypeError',
'Fingerprint',
'InvalidURL',
'RequestInfo',
'ServerConnectionError',
'ServerDisconnectedError',
'ServerFingerprintMismatch',
'ServerTimeoutError',
'TCPConnector',
'TooManyRedirects',
'UnixConnector',
'NamedPipeConnector',
'WSServerHandshakeError',
'request',
# cookiejar
'CookieJar',
'DummyCookieJar',
# formdata
'FormData',
# helpers
'BasicAuth',
'ChainMapProxy',
# http
'HttpVersion',
'HttpVersion10',
'HttpVersion11',
'WSMsgType',
'WSCloseCode',
'WSMessage',
'WebSocketError',
# multipart
'BadContentDispositionHeader',
'BadContentDispositionParam',
'BodyPartReader',
'MultipartReader',
'MultipartWriter',
'content_disposition_filename',
'parse_content_disposition',
# payload
'AsyncIterablePayload',
'BufferedReaderPayload',
'BytesIOPayload',
'BytesPayload',
'IOBasePayload',
'JsonPayload',
'PAYLOAD_REGISTRY',
'Payload',
'StringIOPayload',
'StringPayload',
'TextIOPayload',
'get_payload',
'payload_type',
# payload_streamer
'streamer',
# resolver
'AsyncResolver',
'DefaultResolver',
'ThreadedResolver',
# signals
'Signal',
'DataQueue',
'EMPTY_PAYLOAD',
'EofStream',
'FlowControlDataQueue',
'StreamReader',
# tracing
'TraceConfig',
'TraceConnectionCreateEndParams',
'TraceConnectionCreateStartParams',
'TraceConnectionQueuedEndParams',
'TraceConnectionQueuedStartParams',
'TraceConnectionReuseconnParams',
'TraceDnsCacheHitParams',
'TraceDnsCacheMissParams',
'TraceDnsResolveHostEndParams',
'TraceDnsResolveHostStartParams',
'TraceRequestChunkSentParams',
'TraceRequestEndParams',
'TraceRequestExceptionParams',
'TraceRequestRedirectParams',
'TraceRequestStartParams',
'TraceResponseChunkReceivedParams',
) # type: Tuple[str, ...]
from multidict import * # noqa
from . import hdrs # noqa
from .protocol import * # noqa
from .connector import * # noqa
from .client import * # noqa
from .client_reqrep import * # noqa
from .errors import * # noqa
from .helpers import * # noqa
from .parsers import * # noqa
from .streams import * # noqa
from .multipart import * # noqa
from .client_ws import ClientWebSocketResponse # noqa
from ._ws_impl import WSMsgType, WSCloseCode, WSMessage, WebSocketError # noqa
from .file_sender import FileSender # noqa
from .cookiejar import CookieJar # noqa
from .resolver import * # noqa
try:
from .worker import GunicornWebWorker, GunicornUVLoopWebWorker # noqa
__all__ += ('GunicornWebWorker', 'GunicornUVLoopWebWorker')
except ImportError: # pragma: no cover
pass
MsgType = WSMsgType # backward compatibility
__all__ = (client.__all__ + # noqa
client_reqrep.__all__ + # noqa
errors.__all__ + # noqa
helpers.__all__ + # noqa
parsers.__all__ + # noqa
protocol.__all__ + # noqa
connector.__all__ + # noqa
streams.__all__ + # noqa
multidict.__all__ + # noqa
multipart.__all__ + # noqa
('hdrs', 'FileSender', 'WSMsgType', 'MsgType', 'WSCloseCode',
'WebSocketError', 'WSMessage',
'ClientWebSocketResponse', 'CookieJar'))

@ -1,140 +0,0 @@
from libc.stdint cimport uint16_t, uint32_t, uint64_t
cdef extern from "../vendor/http-parser/http_parser.h":
ctypedef int (*http_data_cb) (http_parser*,
const char *at,
size_t length) except -1
ctypedef int (*http_cb) (http_parser*) except -1
struct http_parser:
unsigned int type
unsigned int flags
unsigned int state
unsigned int header_state
unsigned int index
uint32_t nread
uint64_t content_length
unsigned short http_major
unsigned short http_minor
unsigned int status_code
unsigned int method
unsigned int http_errno
unsigned int upgrade
void *data
struct http_parser_settings:
http_cb on_message_begin
http_data_cb on_url
http_data_cb on_status
http_data_cb on_header_field
http_data_cb on_header_value
http_cb on_headers_complete
http_data_cb on_body
http_cb on_message_complete
http_cb on_chunk_header
http_cb on_chunk_complete
enum http_parser_type:
HTTP_REQUEST,
HTTP_RESPONSE,
HTTP_BOTH
enum http_errno:
HPE_OK,
HPE_CB_message_begin,
HPE_CB_url,
HPE_CB_header_field,
HPE_CB_header_value,
HPE_CB_headers_complete,
HPE_CB_body,
HPE_CB_message_complete,
HPE_CB_status,
HPE_CB_chunk_header,
HPE_CB_chunk_complete,
HPE_INVALID_EOF_STATE,
HPE_HEADER_OVERFLOW,
HPE_CLOSED_CONNECTION,
HPE_INVALID_VERSION,
HPE_INVALID_STATUS,
HPE_INVALID_METHOD,
HPE_INVALID_URL,
HPE_INVALID_HOST,
HPE_INVALID_PORT,
HPE_INVALID_PATH,
HPE_INVALID_QUERY_STRING,
HPE_INVALID_FRAGMENT,
HPE_LF_EXPECTED,
HPE_INVALID_HEADER_TOKEN,
HPE_INVALID_CONTENT_LENGTH,
HPE_INVALID_CHUNK_SIZE,
HPE_INVALID_CONSTANT,
HPE_INVALID_INTERNAL_STATE,
HPE_STRICT,
HPE_PAUSED,
HPE_UNKNOWN
enum flags:
F_CHUNKED,
F_CONNECTION_KEEP_ALIVE,
F_CONNECTION_CLOSE,
F_CONNECTION_UPGRADE,
F_TRAILING,
F_UPGRADE,
F_SKIPBODY,
F_CONTENTLENGTH
enum http_method:
DELETE, GET, HEAD, POST, PUT, CONNECT, OPTIONS, TRACE, COPY,
LOCK, MKCOL, MOVE, PROPFIND, PROPPATCH, SEARCH, UNLOCK, BIND,
REBIND, UNBIND, ACL, REPORT, MKACTIVITY, CHECKOUT, MERGE,
MSEARCH, NOTIFY, SUBSCRIBE, UNSUBSCRIBE, PATCH, PURGE, MKCALENDAR,
LINK, UNLINK
void http_parser_init(http_parser *parser, http_parser_type type)
size_t http_parser_execute(http_parser *parser,
const http_parser_settings *settings,
const char *data,
size_t len)
int http_should_keep_alive(const http_parser *parser)
void http_parser_settings_init(http_parser_settings *settings)
const char *http_errno_name(http_errno err)
const char *http_errno_description(http_errno err)
const char *http_method_str(http_method m)
# URL Parser
enum http_parser_url_fields:
UF_SCHEMA = 0,
UF_HOST = 1,
UF_PORT = 2,
UF_PATH = 3,
UF_QUERY = 4,
UF_FRAGMENT = 5,
UF_USERINFO = 6,
UF_MAX = 7
struct http_parser_url_field_data:
uint16_t off
uint16_t len
struct http_parser_url:
uint16_t field_set
uint16_t port
http_parser_url_field_data[<int>UF_MAX] field_data
void http_parser_url_init(http_parser_url *u)
int http_parser_parse_url(const char *buf,
size_t buflen,
int is_connect,
http_parser_url *u)

File diff suppressed because it is too large

@ -1,14 +0,0 @@
#ifndef _FIND_HEADERS_H
#define _FIND_HEADERS_H
#ifdef __cplusplus
extern "C" {
#endif
int find_header(const char *str, int size);
#ifdef __cplusplus
}
#endif
#endif

@ -1,2 +0,0 @@
cdef extern from "_find_header.h":
int find_header(char *, int)

File diff suppressed because it is too large

File diff suppressed because it is too large

@ -6,8 +6,8 @@ cdef extern from "Python.h":
from libc.stdint cimport uint32_t, uint64_t, uintmax_t
def _websocket_mask_cython(object mask, object data):
"""Note, this function mutates its `data` argument
def _websocket_mask_cython(bytes mask, bytearray data):
"""Note, this function mutates it's `data` argument
"""
cdef:
Py_ssize_t data_len, i
@ -19,14 +19,6 @@ def _websocket_mask_cython(object mask, object data):
assert len(mask) == 4
if not isinstance(mask, bytes):
mask = bytes(mask)
if isinstance(data, bytearray):
data = <bytearray>data
else:
data = bytearray(data)
data_len = len(data)
in_buf = <unsigned char*>PyByteArray_AsString(data)
mask_buf = <const unsigned char*>PyBytes_AsString(mask)
@ -52,3 +44,5 @@ def _websocket_mask_cython(object mask, object data):
for i in range(0, data_len):
in_buf[i] ^= mask_buf[i]
return data

@ -0,0 +1,438 @@
"""WebSocket protocol versions 13 and 8."""
import base64
import binascii
import collections
import hashlib
import json
import os
import random
import sys
from enum import IntEnum
from struct import Struct
from aiohttp import errors, hdrs
from aiohttp.log import ws_logger
__all__ = ('WebSocketParser', 'WebSocketWriter', 'do_handshake',
'WSMessage', 'WebSocketError', 'WSMsgType', 'WSCloseCode')
class WSCloseCode(IntEnum):
OK = 1000
GOING_AWAY = 1001
PROTOCOL_ERROR = 1002
UNSUPPORTED_DATA = 1003
INVALID_TEXT = 1007
POLICY_VIOLATION = 1008
MESSAGE_TOO_BIG = 1009
MANDATORY_EXTENSION = 1010
INTERNAL_ERROR = 1011
SERVICE_RESTART = 1012
TRY_AGAIN_LATER = 1013
ALLOWED_CLOSE_CODES = {int(i) for i in WSCloseCode}
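# --- Added illustration (not part of aiohttp): close codes below 3000
# must be one of the registered WSCloseCode values above; codes of 3000
# and above are left to applications and libraries, so the parser only
# validates the sub-3000 range (see the checks in WebSocketParser below).
def _demo_close_code_check():  # hypothetical helper, for illustration only
    assert 1000 in ALLOWED_CLOSE_CODES      # WSCloseCode.OK
    assert 1004 not in ALLOWED_CLOSE_CODES  # reserved, the parser rejects it
    assert 3000 not in ALLOWED_CLOSE_CODES  # still accepted, since >= 3000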
class WSMsgType(IntEnum):
CONTINUATION = 0x0
TEXT = 0x1
BINARY = 0x2
PING = 0x9
PONG = 0xa
CLOSE = 0x8
CLOSED = 0x101
ERROR = 0x102
text = TEXT
binary = BINARY
ping = PING
pong = PONG
close = CLOSE
closed = CLOSED
error = ERROR
WS_KEY = b'258EAFA5-E914-47DA-95CA-C5AB0DC85B11'
UNPACK_LEN2 = Struct('!H').unpack_from
UNPACK_LEN3 = Struct('!Q').unpack_from
UNPACK_CLOSE_CODE = Struct('!H').unpack
PACK_LEN1 = Struct('!BB').pack
PACK_LEN2 = Struct('!BBH').pack
PACK_LEN3 = Struct('!BBQ').pack
PACK_CLOSE_CODE = Struct('!H').pack
MSG_SIZE = 2 ** 14
_WSMessageBase = collections.namedtuple('_WSMessageBase',
['type', 'data', 'extra'])
class WSMessage(_WSMessageBase):
def json(self, *, loads=json.loads):
"""Return parsed JSON data.
.. versionadded:: 0.22
"""
return loads(self.data)
@property
def tp(self):
return self.type
CLOSED_MESSAGE = WSMessage(WSMsgType.CLOSED, None, None)
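# --- Added illustration (not part of aiohttp): WSMessage is a plain
# namedtuple, so json() simply runs json.loads over the `data` field,
# and `tp` is a backward-compatible alias for `type`.
def _demo_wsmessage():  # hypothetical helper, for illustration only
    msg = WSMessage(WSMsgType.TEXT, '{"answer": 42}', '')
    assert msg.json() == {'answer': 42}
    assert msg.tp is msg.type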
class WebSocketError(Exception):
"""WebSocket protocol parser error."""
def __init__(self, code, message):
self.code = code
super().__init__(message)
def WebSocketParser(out, buf):
while True:
fin, opcode, payload = yield from parse_frame(buf)
if opcode == WSMsgType.CLOSE:
if len(payload) >= 2:
close_code = UNPACK_CLOSE_CODE(payload[:2])[0]
if close_code < 3000 and close_code not in ALLOWED_CLOSE_CODES:
raise WebSocketError(
WSCloseCode.PROTOCOL_ERROR,
'Invalid close code: {}'.format(close_code))
try:
close_message = payload[2:].decode('utf-8')
except UnicodeDecodeError as exc:
raise WebSocketError(
WSCloseCode.INVALID_TEXT,
'Invalid UTF-8 text message') from exc
msg = WSMessage(WSMsgType.CLOSE, close_code, close_message)
elif payload:
raise WebSocketError(
WSCloseCode.PROTOCOL_ERROR,
'Invalid close frame: {} {} {!r}'.format(
fin, opcode, payload))
else:
msg = WSMessage(WSMsgType.CLOSE, 0, '')
out.feed_data(msg, 0)
elif opcode == WSMsgType.PING:
out.feed_data(WSMessage(WSMsgType.PING, payload, ''), len(payload))
elif opcode == WSMsgType.PONG:
out.feed_data(WSMessage(WSMsgType.PONG, payload, ''), len(payload))
elif opcode not in (WSMsgType.TEXT, WSMsgType.BINARY):
raise WebSocketError(
WSCloseCode.PROTOCOL_ERROR,
"Unexpected opcode={!r}".format(opcode))
else:
# load text/binary
data = [payload]
while not fin:
fin, _opcode, payload = yield from parse_frame(buf, True)
# We can receive a ping/close in the middle of a
# text message (Autobahn test suite, case 5.*)
if _opcode == WSMsgType.PING:
out.feed_data(
WSMessage(WSMsgType.PING, payload, ''), len(payload))
fin, _opcode, payload = yield from parse_frame(buf, True)
elif _opcode == WSMsgType.CLOSE:
if len(payload) >= 2:
close_code = UNPACK_CLOSE_CODE(payload[:2])[0]
if (close_code not in ALLOWED_CLOSE_CODES and
close_code < 3000):
raise WebSocketError(
WSCloseCode.PROTOCOL_ERROR,
'Invalid close code: {}'.format(close_code))
try:
close_message = payload[2:].decode('utf-8')
except UnicodeDecodeError as exc:
raise WebSocketError(
WSCloseCode.INVALID_TEXT,
'Invalid UTF-8 text message') from exc
msg = WSMessage(WSMsgType.CLOSE, close_code,
close_message)
elif payload:
raise WebSocketError(
WSCloseCode.PROTOCOL_ERROR,
'Invalid close frame: {} {} {!r}'.format(
fin, opcode, payload))
else:
msg = WSMessage(WSMsgType.CLOSE, 0, '')
out.feed_data(msg, 0)
fin, _opcode, payload = yield from parse_frame(buf, True)
if _opcode != WSMsgType.CONTINUATION:
raise WebSocketError(
WSCloseCode.PROTOCOL_ERROR,
'The opcode in a non-fin frame is expected '
'to be zero, got {!r}'.format(_opcode))
else:
data.append(payload)
if opcode == WSMsgType.TEXT:
try:
text = b''.join(data).decode('utf-8')
out.feed_data(WSMessage(WSMsgType.TEXT, text, ''),
len(text))
except UnicodeDecodeError as exc:
raise WebSocketError(
WSCloseCode.INVALID_TEXT,
'Invalid UTF-8 text message') from exc
else:
data = b''.join(data)
out.feed_data(
WSMessage(WSMsgType.BINARY, data, ''), len(data))
native_byteorder = sys.byteorder
def _websocket_mask_python(mask, data):
"""Websocket masking function.
`mask` is a `bytes` object of length 4; `data` is a `bytes` object
of any length. Returns a `bytes` object of the same length as
`data` with the mask applied as specified in section 5.3 of RFC
6455.
This pure-python implementation may be replaced by an optimized
version when available.
"""
assert isinstance(data, bytearray), data
assert len(mask) == 4, mask
datalen = len(data)
if datalen == 0:
# everything works without this, but this may change in a later Python.
return bytearray()
data = int.from_bytes(data, native_byteorder)
mask = int.from_bytes(mask * (datalen // 4) + mask[: datalen % 4],
native_byteorder)
return (data ^ mask).to_bytes(datalen, native_byteorder)
if bool(os.environ.get('AIOHTTP_NO_EXTENSIONS')):
_websocket_mask = _websocket_mask_python
else:
try:
from ._websocket import _websocket_mask_cython
_websocket_mask = _websocket_mask_cython
except ImportError: # pragma: no cover
_websocket_mask = _websocket_mask_python
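# --- Added illustration (not part of aiohttp): masking is a bytewise XOR
# with the 4-byte mask repeated across the payload, so applying the same
# mask twice restores the original bytes (RFC 6455, section 5.3).
def _demo_mask_roundtrip():  # hypothetical helper, for illustration only
    mask = b'\x01\x02\x03\x04'
    payload = bytearray(b'hello websocket')
    masked = _websocket_mask_python(mask, payload)
    assert masked != payload
    assert _websocket_mask_python(mask, bytearray(masked)) == payload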
def parse_frame(buf, continuation=False):
"""Return the next frame from the socket."""
# read header
data = yield from buf.read(2)
first_byte, second_byte = data
fin = (first_byte >> 7) & 1
rsv1 = (first_byte >> 6) & 1
rsv2 = (first_byte >> 5) & 1
rsv3 = (first_byte >> 4) & 1
opcode = first_byte & 0xf
# frame-fin = %x0 ; more frames of this message follow
# / %x1 ; final frame of this message
# frame-rsv1 = %x0 ; 1 bit, MUST be 0 unless negotiated otherwise
# frame-rsv2 = %x0 ; 1 bit, MUST be 0 unless negotiated otherwise
# frame-rsv3 = %x0 ; 1 bit, MUST be 0 unless negotiated otherwise
if rsv1 or rsv2 or rsv3:
raise WebSocketError(
WSCloseCode.PROTOCOL_ERROR,
'Received frame with non-zero reserved bits')
if opcode > 0x7 and fin == 0:
raise WebSocketError(
WSCloseCode.PROTOCOL_ERROR,
'Received fragmented control frame')
if fin == 0 and opcode == WSMsgType.CONTINUATION and not continuation:
raise WebSocketError(
WSCloseCode.PROTOCOL_ERROR,
'Received new fragment frame with non-zero '
'opcode {!r}'.format(opcode))
has_mask = (second_byte >> 7) & 1
length = (second_byte) & 0x7f
# Control frames MUST have a payload length of 125 bytes or less
if opcode > 0x7 and length > 125:
raise WebSocketError(
WSCloseCode.PROTOCOL_ERROR,
"Control frame payload cannot be larger than 125 bytes")
# read payload
if length == 126:
data = yield from buf.read(2)
length = UNPACK_LEN2(data)[0]
elif length > 126:
data = yield from buf.read(8)
length = UNPACK_LEN3(data)[0]
if has_mask:
mask = yield from buf.read(4)
if length:
payload = yield from buf.read(length)
else:
payload = bytearray()
if has_mask:
payload = _websocket_mask(bytes(mask), payload)
return fin, opcode, payload
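# --- Added illustration (not part of aiohttp): how the first two header
# bytes of a frame decode. 0x81 is FIN set plus opcode TEXT; 0x85 is the
# MASK bit set plus a payload length of 5 (no length extension needed).
def _demo_header_bits():  # hypothetical helper, for illustration only
    first_byte, second_byte = 0x81, 0x85
    fin = (first_byte >> 7) & 1          # 1: final frame
    opcode = first_byte & 0xf            # 0x1: WSMsgType.TEXT
    has_mask = (second_byte >> 7) & 1    # 1: payload is masked
    length = second_byte & 0x7f          # 5: payload length
    assert (fin, opcode, has_mask, length) == (1, WSMsgType.TEXT, 1, 5)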
class WebSocketWriter:
def __init__(self, writer, *, use_mask=False, random=random.Random()):
self.writer = writer
self.use_mask = use_mask
self.randrange = random.randrange
def _send_frame(self, message, opcode):
"""Send a frame over the websocket with message as its payload."""
msg_length = len(message)
use_mask = self.use_mask
if use_mask:
mask_bit = 0x80
else:
mask_bit = 0
if msg_length < 126:
header = PACK_LEN1(0x80 | opcode, msg_length | mask_bit)
elif msg_length < (1 << 16):
header = PACK_LEN2(0x80 | opcode, 126 | mask_bit, msg_length)
else:
header = PACK_LEN3(0x80 | opcode, 127 | mask_bit, msg_length)
if use_mask:
mask = self.randrange(0, 0xffffffff)
mask = mask.to_bytes(4, 'big')
message = _websocket_mask(mask, bytearray(message))
self.writer.write(header + mask + message)
else:
if len(message) > MSG_SIZE:
self.writer.write(header)
self.writer.write(message)
else:
self.writer.write(header + message)
def pong(self, message=b''):
"""Send pong message."""
if isinstance(message, str):
message = message.encode('utf-8')
self._send_frame(message, WSMsgType.PONG)
def ping(self, message=b''):
"""Send ping message."""
if isinstance(message, str):
message = message.encode('utf-8')
self._send_frame(message, WSMsgType.PING)
def send(self, message, binary=False):
"""Send a frame over the websocket with message as its payload."""
if isinstance(message, str):
message = message.encode('utf-8')
if binary:
self._send_frame(message, WSMsgType.BINARY)
else:
self._send_frame(message, WSMsgType.TEXT)
def close(self, code=1000, message=b''):
"""Close the websocket, sending the specified code and message."""
if isinstance(message, str):
message = message.encode('utf-8')
self._send_frame(
PACK_CLOSE_CODE(code) + message, opcode=WSMsgType.CLOSE)
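# --- Added illustration (not part of aiohttp): _send_frame() picks one of
# three header layouts by payload size (RFC 6455, section 5.2): lengths
# under 126 are encoded inline, lengths under 2**16 use a 16-bit extension
# (marker 126), and anything larger uses a 64-bit extension (marker 127).
def _demo_length_encodings():  # hypothetical helper, for illustration only
    first = 0x80 | WSMsgType.BINARY  # FIN bit + opcode
    assert len(PACK_LEN1(first, 125)) == 2
    assert len(PACK_LEN2(first, 126, 60000)) == 4
    assert len(PACK_LEN3(first, 127, 2 ** 20)) == 10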
def do_handshake(method, headers, transport, protocols=()):
"""Prepare WebSocket handshake.
It returns the HTTP response code, response headers, websocket parser,
and websocket writer. It does not perform any IO.
`protocols` is a sequence of known protocols. On successful handshake,
the returned response headers contain the first protocol in this list
which the server also knows.
"""
# WebSocket accepts only GET
if method.upper() != hdrs.METH_GET:
raise errors.HttpProcessingError(
code=405, headers=((hdrs.ALLOW, hdrs.METH_GET),))
if 'websocket' != headers.get(hdrs.UPGRADE, '').lower().strip():
raise errors.HttpBadRequest(
message='No WebSocket UPGRADE hdr: {}\n Can '
'"Upgrade" only to "WebSocket".'.format(headers.get(hdrs.UPGRADE)))
if 'upgrade' not in headers.get(hdrs.CONNECTION, '').lower():
raise errors.HttpBadRequest(
message='No CONNECTION upgrade hdr: {}'.format(
headers.get(hdrs.CONNECTION)))
# find common sub-protocol between client and server
protocol = None
if hdrs.SEC_WEBSOCKET_PROTOCOL in headers:
req_protocols = [str(proto.strip()) for proto in
headers[hdrs.SEC_WEBSOCKET_PROTOCOL].split(',')]
for proto in req_protocols:
if proto in protocols:
protocol = proto
break
else:
# No overlap found: Return no protocol as per spec
ws_logger.warning(
"Client protocols %r don't overlap server-known ones %r",
req_protocols, protocols)
# check supported version
version = headers.get(hdrs.SEC_WEBSOCKET_VERSION, '')
if version not in ('13', '8', '7'):
raise errors.HttpBadRequest(
message='Unsupported version: {}'.format(version),
headers=((hdrs.SEC_WEBSOCKET_VERSION, '13'),))
# check client handshake for validity
key = headers.get(hdrs.SEC_WEBSOCKET_KEY)
try:
if not key or len(base64.b64decode(key)) != 16:
raise errors.HttpBadRequest(
message='Handshake error: {!r}'.format(key))
except binascii.Error:
raise errors.HttpBadRequest(
message='Handshake error: {!r}'.format(key)) from None
response_headers = [
(hdrs.UPGRADE, 'websocket'),
(hdrs.CONNECTION, 'upgrade'),
(hdrs.TRANSFER_ENCODING, 'chunked'),
(hdrs.SEC_WEBSOCKET_ACCEPT, base64.b64encode(
hashlib.sha1(key.encode() + WS_KEY).digest()).decode())]
if protocol:
response_headers.append((hdrs.SEC_WEBSOCKET_PROTOCOL, protocol))
# response code, headers, parser, writer, protocol
return (101,
response_headers,
WebSocketParser,
WebSocketWriter(transport),
protocol)
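# --- Added illustration (not part of aiohttp): the accept token above is
# the base64-encoded SHA-1 of the client key concatenated with the fixed
# WS_KEY GUID; the key/token pair below is the worked example from
# RFC 6455, section 4.2.2.
def _demo_accept_token():  # hypothetical helper, for illustration only
    key = 'dGhlIHNhbXBsZSBub25jZQ=='
    accept = base64.b64encode(
        hashlib.sha1(key.encode() + WS_KEY).digest()).decode()
    assert accept == 's3pPLMBiTxaQ9kYGzzhZRbK+xOo='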

@ -1,208 +1,88 @@
import asyncio
import logging
import sys
from abc import ABC, abstractmethod
from collections.abc import Sized
from http.cookies import BaseCookie, Morsel # noqa
from typing import (
TYPE_CHECKING,
Any,
Awaitable,
Callable,
Dict,
Generator,
Iterable,
List,
Optional,
Tuple,
)
from multidict import CIMultiDict # noqa
from yarl import URL
from .helpers import get_running_loop
from .typedefs import LooseCookies
if TYPE_CHECKING: # pragma: no cover
from .web_request import BaseRequest, Request
from .web_response import StreamResponse
from .web_app import Application
from .web_exceptions import HTTPException
else:
BaseRequest = Request = Application = StreamResponse = None
HTTPException = None
from collections.abc import Iterable, Sized
PY_35 = sys.version_info >= (3, 5)
class AbstractRouter(ABC):
def __init__(self) -> None:
self._frozen = False
def post_init(self, app: Application) -> None:
"""Post init stage.
Not an abstract method for sake of backward compatibility,
but if the router wants to be aware of the application
it can override this.
"""
@property
def frozen(self) -> bool:
return self._frozen
def freeze(self) -> None:
"""Freeze router."""
self._frozen = True
class AbstractRouter(ABC):
@asyncio.coroutine # pragma: no branch
@abstractmethod
async def resolve(self, request: Request) -> 'AbstractMatchInfo':
def resolve(self, request):
"""Return MATCH_INFO for given request"""
class AbstractMatchInfo(ABC):
@property # pragma: no branch
@asyncio.coroutine # pragma: no branch
@abstractmethod
def handler(self) -> Callable[[Request], Awaitable[StreamResponse]]:
def handler(self, request):
"""Execute matched request handler"""
@property
@asyncio.coroutine # pragma: no branch
@abstractmethod
def expect_handler(self) -> Callable[[Request], Awaitable[None]]:
def expect_handler(self, request):
"""Expect handler for 100-continue processing"""
@property # pragma: no branch
@abstractmethod
def http_exception(self) -> Optional[HTTPException]:
def http_exception(self):
"""HTTPException instance raised on router's resolving, or None"""
@abstractmethod # pragma: no branch
def get_info(self) -> Dict[str, Any]:
def get_info(self):
"""Return a dict with additional info useful for introspection"""
@property # pragma: no branch
@abstractmethod
def apps(self) -> Tuple[Application, ...]:
"""Stack of nested applications.
Top level application is left-most element.
"""
@abstractmethod
def add_app(self, app: Application) -> None:
"""Add application to the nested apps stack."""
@abstractmethod
def freeze(self) -> None:
"""Freeze the match info.
The method is called after route resolution.
After the call .add_app() is forbidden.
"""
class AbstractView(ABC):
"""Abstract class based view."""
def __init__(self, request: Request) -> None:
def __init__(self, request):
self._request = request
@property
def request(self) -> Request:
"""Request instance."""
def request(self):
return self._request
@asyncio.coroutine # pragma: no branch
@abstractmethod
def __await__(self) -> Generator[Any, None, StreamResponse]:
"""Execute the view handler."""
def __iter__(self):
while False: # pragma: no cover
yield None
if PY_35: # pragma: no branch
@abstractmethod
def __await__(self):
return # pragma: no cover
class AbstractResolver(ABC):
"""Abstract DNS resolver."""
@asyncio.coroutine # pragma: no branch
@abstractmethod
async def resolve(self, host: str,
port: int, family: int) -> List[Dict[str, Any]]:
def resolve(self, hostname):
"""Return IP address for given hostname"""
@asyncio.coroutine # pragma: no branch
@abstractmethod
async def close(self) -> None:
def close(self):
"""Release resolver"""
if TYPE_CHECKING: # pragma: no cover
IterableBase = Iterable[Morsel[str]]
else:
IterableBase = Iterable
class AbstractCookieJar(Sized, Iterable):
class AbstractCookieJar(Sized, IterableBase):
"""Abstract Cookie Jar."""
def __init__(self, *,
loop: Optional[asyncio.AbstractEventLoop]=None) -> None:
self._loop = get_running_loop(loop)
def __init__(self, *, loop=None):
self._loop = loop or asyncio.get_event_loop()
@abstractmethod
def clear(self) -> None:
def clear(self):
"""Clear all cookies."""
@abstractmethod
def update_cookies(self,
cookies: LooseCookies,
response_url: URL=URL()) -> None:
def update_cookies(self, cookies, response_url=None):
"""Update cookies."""
@abstractmethod
def filter_cookies(self, request_url: URL) -> 'BaseCookie[str]':
def filter_cookies(self, request_url):
"""Return the jar's cookies filtered by their attributes."""
class AbstractStreamWriter(ABC):
"""Abstract stream writer."""
buffer_size = 0
output_size = 0
length = 0 # type: Optional[int]
@abstractmethod
async def write(self, chunk: bytes) -> None:
"""Write chunk into stream."""
@abstractmethod
async def write_eof(self, chunk: bytes=b'') -> None:
"""Write last chunk."""
@abstractmethod
async def drain(self) -> None:
"""Flush the write buffer."""
@abstractmethod
def enable_compression(self, encoding: str='deflate') -> None:
"""Enable HTTP body compression"""
@abstractmethod
def enable_chunking(self) -> None:
"""Enable HTTP chunked mode"""
@abstractmethod
async def write_headers(self, status_line: str,
headers: 'CIMultiDict[str]') -> None:
"""Write HTTP headers"""
class AbstractAccessLogger(ABC):
"""Abstract writer to access log."""
def __init__(self, logger: logging.Logger, log_format: str) -> None:
self.logger = logger
self.log_format = log_format
@abstractmethod
def log(self,
request: BaseRequest,
response: StreamResponse,
time: float) -> None:
"""Emit log to logger."""

@ -1,81 +0,0 @@
import asyncio
from typing import Optional, cast
from .tcp_helpers import tcp_nodelay
class BaseProtocol(asyncio.Protocol):
__slots__ = ('_loop', '_paused', '_drain_waiter',
'_connection_lost', '_reading_paused', 'transport')
def __init__(self, loop: asyncio.AbstractEventLoop) -> None:
self._loop = loop # type: asyncio.AbstractEventLoop
self._paused = False
self._drain_waiter = None # type: Optional[asyncio.Future[None]]
self._connection_lost = False
self._reading_paused = False
self.transport = None # type: Optional[asyncio.Transport]
def pause_writing(self) -> None:
assert not self._paused
self._paused = True
def resume_writing(self) -> None:
assert self._paused
self._paused = False
waiter = self._drain_waiter
if waiter is not None:
self._drain_waiter = None
if not waiter.done():
waiter.set_result(None)
def pause_reading(self) -> None:
if not self._reading_paused and self.transport is not None:
try:
self.transport.pause_reading()
except (AttributeError, NotImplementedError, RuntimeError):
pass
self._reading_paused = True
def resume_reading(self) -> None:
if self._reading_paused and self.transport is not None:
try:
self.transport.resume_reading()
except (AttributeError, NotImplementedError, RuntimeError):
pass
self._reading_paused = False
def connection_made(self, transport: asyncio.BaseTransport) -> None:
tr = cast(asyncio.Transport, transport)
tcp_nodelay(tr, True)
self.transport = tr
def connection_lost(self, exc: Optional[BaseException]) -> None:
self._connection_lost = True
# Wake up the writer if currently paused.
self.transport = None
if not self._paused:
return
waiter = self._drain_waiter
if waiter is None:
return
self._drain_waiter = None
if waiter.done():
return
if exc is None:
waiter.set_result(None)
else:
waiter.set_exception(exc)
async def _drain_helper(self) -> None:
if self._connection_lost:
raise ConnectionResetError('Connection lost')
if not self._paused:
return
waiter = self._drain_waiter
assert waiter is None or waiter.cancelled()
waiter = self._loop.create_future()
self._drain_waiter = waiter
await waiter
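# --- Added illustration (not part of aiohttp): the pause/resume/drain
# methods above implement asyncio's standard write flow control. A writer
# awaits _drain_helper() after each write; the await returns immediately
# unless the transport has called pause_writing(), in which case it blocks
# until resume_writing() resolves the drain waiter.
async def demo_send_all(proto, chunks):  # hypothetical, for illustration only
    for chunk in chunks:
        assert proto.transport is not None  # connection must be alive
        proto.transport.write(chunk)
        await proto._drain_helper()  # cooperate with backpressure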

File diff suppressed because it is too large

@ -1,292 +0,0 @@
"""HTTP related errors."""
import asyncio
import warnings
from typing import TYPE_CHECKING, Any, Optional, Tuple, Union
from .typedefs import _CIMultiDict
try:
import ssl
SSLContext = ssl.SSLContext
except ImportError: # pragma: no cover
ssl = SSLContext = None # type: ignore
if TYPE_CHECKING: # pragma: no cover
from .client_reqrep import (RequestInfo, ClientResponse, ConnectionKey, # noqa
Fingerprint)
else:
RequestInfo = ClientResponse = ConnectionKey = None
__all__ = (
'ClientError',
'ClientConnectionError',
'ClientOSError', 'ClientConnectorError', 'ClientProxyConnectionError',
'ClientSSLError',
'ClientConnectorSSLError', 'ClientConnectorCertificateError',
'ServerConnectionError', 'ServerTimeoutError', 'ServerDisconnectedError',
'ServerFingerprintMismatch',
'ClientResponseError', 'ClientHttpProxyError',
'WSServerHandshakeError', 'ContentTypeError',
'ClientPayloadError', 'InvalidURL')
class ClientError(Exception):
"""Base class for client connection errors."""
class ClientResponseError(ClientError):
"""Connection error during reading response.
request_info: instance of RequestInfo
"""
def __init__(self, request_info: RequestInfo,
history: Tuple[ClientResponse, ...], *,
code: Optional[int]=None,
status: Optional[int]=None,
message: str='',
headers: Optional[_CIMultiDict]=None) -> None:
self.request_info = request_info
if code is not None:
if status is not None:
raise ValueError(
"Both code and status arguments are provided; "
"code is deprecated, use status instead")
warnings.warn("code argument is deprecated, use status instead",
DeprecationWarning,
stacklevel=2)
if status is not None:
self.status = status
elif code is not None:
self.status = code
else:
self.status = 0
self.message = message
self.headers = headers
self.history = history
self.args = (request_info, history)
def __str__(self) -> str:
return ("%s, message=%r, url=%r" %
(self.status, self.message, self.request_info.real_url))
def __repr__(self) -> str:
args = "%r, %r" % (self.request_info, self.history)
if self.status != 0:
args += ", status=%r" % (self.status,)
if self.message != '':
args += ", message=%r" % (self.message,)
if self.headers is not None:
args += ", headers=%r" % (self.headers,)
return "%s(%s)" % (type(self).__name__, args)
@property
def code(self) -> int:
warnings.warn("code property is deprecated, use status instead",
DeprecationWarning,
stacklevel=2)
return self.status
@code.setter
def code(self, value: int) -> None:
warnings.warn("code property is deprecated, use status instead",
DeprecationWarning,
stacklevel=2)
self.status = value
class ContentTypeError(ClientResponseError):
"""ContentType found is not valid."""
class WSServerHandshakeError(ClientResponseError):
"""websocket server handshake error."""
class ClientHttpProxyError(ClientResponseError):
"""HTTP proxy error.
Raised in :class:`aiohttp.connector.TCPConnector` if
proxy responds with status other than ``200 OK``
on ``CONNECT`` request.
"""
class TooManyRedirects(ClientResponseError):
"""Client was redirected too many times."""
class ClientConnectionError(ClientError):
"""Base class for client socket errors."""
class ClientOSError(ClientConnectionError, OSError):
"""OSError error."""
class ClientConnectorError(ClientOSError):
"""Client connector error.
Raised in :class:`aiohttp.connector.TCPConnector` if
connection to proxy can not be established.
"""
def __init__(self, connection_key: ConnectionKey,
os_error: OSError) -> None:
self._conn_key = connection_key
self._os_error = os_error
super().__init__(os_error.errno, os_error.strerror)
self.args = (connection_key, os_error)
@property
def os_error(self) -> OSError:
return self._os_error
@property
def host(self) -> str:
return self._conn_key.host
@property
def port(self) -> Optional[int]:
return self._conn_key.port
@property
def ssl(self) -> Union[SSLContext, None, bool, 'Fingerprint']:
return self._conn_key.ssl
def __str__(self) -> str:
return ('Cannot connect to host {0.host}:{0.port} ssl:{1} [{2}]'
.format(self, self.ssl if self.ssl is not None else 'default',
self.strerror))
# OSError.__reduce__ does too much black magic
__reduce__ = BaseException.__reduce__
class ClientProxyConnectionError(ClientConnectorError):
"""Proxy connection error.
Raised in :class:`aiohttp.connector.TCPConnector` if
connection to proxy can not be established.
"""
class ServerConnectionError(ClientConnectionError):
"""Server connection errors."""
class ServerDisconnectedError(ServerConnectionError):
"""Server disconnected."""
def __init__(self, message: Optional[str]=None) -> None:
self.message = message
if message is None:
self.args = ()
else:
self.args = (message,)
class ServerTimeoutError(ServerConnectionError, asyncio.TimeoutError):
"""Server timeout error."""
class ServerFingerprintMismatch(ServerConnectionError):
"""SSL certificate does not match expected fingerprint."""
def __init__(self, expected: bytes, got: bytes,
host: str, port: int) -> None:
self.expected = expected
self.got = got
self.host = host
self.port = port
self.args = (expected, got, host, port)
def __repr__(self) -> str:
return '<{} expected={!r} got={!r} host={!r} port={!r}>'.format(
self.__class__.__name__, self.expected, self.got,
self.host, self.port)
class ClientPayloadError(ClientError):
"""Response payload error."""
class InvalidURL(ClientError, ValueError):
"""Invalid URL.
The URL used for fetching is malformed, e.g. it doesn't contain a
host part."""
# Derive from ValueError for backward compatibility
def __init__(self, url: Any) -> None:
# The type of url is not yarl.URL because the exception can be raised
# on URL(url) call
super().__init__(url)
@property
def url(self) -> Any:
return self.args[0]
def __repr__(self) -> str:
return '<{} {}>'.format(self.__class__.__name__, self.url)
class ClientSSLError(ClientConnectorError):
"""Base error for ssl.*Errors."""
if ssl is not None:
cert_errors = (ssl.CertificateError,)
cert_errors_bases = (ClientSSLError, ssl.CertificateError,)
ssl_errors = (ssl.SSLError,)
ssl_error_bases = (ClientSSLError, ssl.SSLError)
else: # pragma: no cover
cert_errors = tuple()
cert_errors_bases = (ClientSSLError, ValueError,)
ssl_errors = tuple()
ssl_error_bases = (ClientSSLError,)
class ClientConnectorSSLError(*ssl_error_bases): # type: ignore
"""Response ssl error."""
class ClientConnectorCertificateError(*cert_errors_bases): # type: ignore
"""Response certificate error."""
def __init__(self, connection_key:
ConnectionKey, certificate_error: Exception) -> None:
self._conn_key = connection_key
self._certificate_error = certificate_error
self.args = (connection_key, certificate_error)
@property
def certificate_error(self) -> Exception:
return self._certificate_error
@property
def host(self) -> str:
return self._conn_key.host
@property
def port(self) -> Optional[int]:
return self._conn_key.port
@property
def ssl(self) -> bool:
return self._conn_key.is_ssl
def __str__(self) -> str:
return ('Cannot connect to host {0.host}:{0.port} ssl:{0.ssl} '
'[{0.certificate_error.__class__.__name__}: '
'{0.certificate_error.args}]'.format(self))
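# --- Added illustration (not part of aiohttp): ClientResponseError keeps
# the legacy `code` attribute as a deprecated alias of `status`; reading
# either name yields the same value.
def demo_status_alias():  # hypothetical helper, for illustration only
    err = ClientResponseError(None, (), status=404, message='Not Found')
    with warnings.catch_warnings():
        warnings.simplefilter('ignore', DeprecationWarning)
        assert err.code == err.status == 404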

@ -1,239 +0,0 @@
import asyncio
from contextlib import suppress
from typing import Any, Optional, Tuple
from .base_protocol import BaseProtocol
from .client_exceptions import (
ClientOSError,
ClientPayloadError,
ServerDisconnectedError,
ServerTimeoutError,
)
from .helpers import BaseTimerContext
from .http import HttpResponseParser, RawResponseMessage
from .streams import EMPTY_PAYLOAD, DataQueue, StreamReader
class ResponseHandler(BaseProtocol,
DataQueue[Tuple[RawResponseMessage, StreamReader]]):
"""Helper class to adapt between Protocol and StreamReader."""
def __init__(self,
loop: asyncio.AbstractEventLoop) -> None:
BaseProtocol.__init__(self, loop=loop)
DataQueue.__init__(self, loop)
self._should_close = False
self._payload = None
self._skip_payload = False
self._payload_parser = None
self._timer = None
self._tail = b''
self._upgraded = False
self._parser = None # type: Optional[HttpResponseParser]
self._read_timeout = None # type: Optional[float]
self._read_timeout_handle = None # type: Optional[asyncio.TimerHandle]
@property
def upgraded(self) -> bool:
return self._upgraded
@property
def should_close(self) -> bool:
if (self._payload is not None and
not self._payload.is_eof() or self._upgraded):
return True
return (self._should_close or self._upgraded or
self.exception() is not None or
self._payload_parser is not None or
len(self) > 0 or bool(self._tail))
def force_close(self) -> None:
self._should_close = True
def close(self) -> None:
transport = self.transport
if transport is not None:
transport.close()
self.transport = None
self._payload = None
self._drop_timeout()
def is_connected(self) -> bool:
return self.transport is not None
def connection_lost(self, exc: Optional[BaseException]) -> None:
self._drop_timeout()
if self._payload_parser is not None:
with suppress(Exception):
self._payload_parser.feed_eof()
uncompleted = None
if self._parser is not None:
try:
uncompleted = self._parser.feed_eof()
except Exception:
if self._payload is not None:
self._payload.set_exception(
ClientPayloadError(
'Response payload is not completed'))
if not self.is_eof():
if isinstance(exc, OSError):
exc = ClientOSError(*exc.args)
if exc is None:
exc = ServerDisconnectedError(uncompleted)
# assigns self._should_close to True as a side effect;
# we do it anyway below
self.set_exception(exc)
self._should_close = True
self._parser = None
self._payload = None
self._payload_parser = None
self._reading_paused = False
super().connection_lost(exc)
def eof_received(self) -> None:
# should call parser.feed_eof() most likely
self._drop_timeout()
def pause_reading(self) -> None:
super().pause_reading()
self._drop_timeout()
def resume_reading(self) -> None:
super().resume_reading()
self._reschedule_timeout()
def set_exception(self, exc: BaseException) -> None:
self._should_close = True
self._drop_timeout()
super().set_exception(exc)
def set_parser(self, parser: Any, payload: Any) -> None:
# TODO: actual types are:
# parser: WebSocketReader
# payload: FlowControlDataQueue
# but they are not generic enough
# Need an ABC for both types
self._payload = payload
self._payload_parser = parser
self._drop_timeout()
if self._tail:
data, self._tail = self._tail, b''
self.data_received(data)
def set_response_params(self, *, timer: BaseTimerContext=None,
skip_payload: bool=False,
read_until_eof: bool=False,
auto_decompress: bool=True,
read_timeout: Optional[float]=None) -> None:
self._skip_payload = skip_payload
self._read_timeout = read_timeout
self._reschedule_timeout()
self._parser = HttpResponseParser(
self, self._loop, timer=timer,
payload_exception=ClientPayloadError,
read_until_eof=read_until_eof,
auto_decompress=auto_decompress)
if self._tail:
data, self._tail = self._tail, b''
self.data_received(data)
def _drop_timeout(self) -> None:
if self._read_timeout_handle is not None:
self._read_timeout_handle.cancel()
self._read_timeout_handle = None
def _reschedule_timeout(self) -> None:
timeout = self._read_timeout
if self._read_timeout_handle is not None:
self._read_timeout_handle.cancel()
if timeout:
self._read_timeout_handle = self._loop.call_later(
timeout, self._on_read_timeout)
else:
self._read_timeout_handle = None
def _on_read_timeout(self) -> None:
exc = ServerTimeoutError("Timeout on reading data from socket")
self.set_exception(exc)
if self._payload is not None:
self._payload.set_exception(exc)
def data_received(self, data: bytes) -> None:
self._reschedule_timeout()
if not data:
return
# custom payload parser
if self._payload_parser is not None:
eof, tail = self._payload_parser.feed_data(data)
if eof:
self._payload = None
self._payload_parser = None
if tail:
self.data_received(tail)
return
else:
if self._upgraded or self._parser is None:
# i.e. websocket connection, websocket parser is not set yet
self._tail += data
else:
# parse http messages
try:
messages, upgraded, tail = self._parser.feed_data(data)
except BaseException as exc:
if self.transport is not None:
# connection.release() could be called BEFORE
# data_received(); in that case the transport
# is already closed
self.transport.close()
# should_close is True after the call
self.set_exception(exc)
return
self._upgraded = upgraded
payload = None
for message, payload in messages:
if message.should_close:
self._should_close = True
self._payload = payload
if self._skip_payload or message.code in (204, 304):
self.feed_data((message, EMPTY_PAYLOAD), 0) # type: ignore # noqa
else:
self.feed_data((message, payload), 0)
if payload is not None:
# new message(s) were processed
# register timeout handler unsubscribing
# either on end-of-stream or immediately for
# EMPTY_PAYLOAD
if payload is not EMPTY_PAYLOAD:
payload.on_eof(self._drop_timeout)
else:
self._drop_timeout()
if tail:
if upgraded:
self.data_received(tail)
else:
self._tail = tail

File diff suppressed because it is too large

@ -1,46 +1,19 @@
"""WebSocket client for asyncio."""
import asyncio
from typing import Any, Optional
import json
import sys
import async_timeout
from ._ws_impl import CLOSED_MESSAGE, WebSocketError, WSMessage, WSMsgType
from .client_exceptions import ClientError
from .client_reqrep import ClientResponse
from .helpers import call_later, set_result
from .http import (
WS_CLOSED_MESSAGE,
WS_CLOSING_MESSAGE,
WebSocketError,
WSMessage,
WSMsgType,
)
from .http_websocket import WebSocketWriter # WSMessage
from .streams import EofStream, FlowControlDataQueue # noqa
from .typedefs import (
DEFAULT_JSON_DECODER,
DEFAULT_JSON_ENCODER,
JSONDecoder,
JSONEncoder,
)
PY_35 = sys.version_info >= (3, 5)
PY_352 = sys.version_info >= (3, 5, 2)
class ClientWebSocketResponse:
def __init__(self,
reader: 'FlowControlDataQueue[WSMessage]',
writer: WebSocketWriter,
protocol: Optional[str],
response: ClientResponse,
timeout: float,
autoclose: bool,
autoping: bool,
loop: asyncio.AbstractEventLoop,
*,
receive_timeout: Optional[float]=None,
heartbeat: Optional[float]=None,
compress: int=0,
client_notakeover: bool=False) -> None:
def __init__(self, reader, writer, protocol,
response, timeout, autoclose, autoping, loop):
self._response = response
self._conn = response.connection
@ -49,128 +22,63 @@ class ClientWebSocketResponse:
self._protocol = protocol
self._closed = False
self._closing = False
self._close_code = None # type: Optional[int]
self._close_code = None
self._timeout = timeout
self._receive_timeout = receive_timeout
self._autoclose = autoclose
self._autoping = autoping
self._heartbeat = heartbeat
self._heartbeat_cb = None
if heartbeat is not None:
self._pong_heartbeat = heartbeat / 2.0
self._pong_response_cb = None
self._loop = loop
self._waiting = None # type: Optional[asyncio.Future[bool]]
self._exception = None # type: Optional[BaseException]
self._compress = compress
self._client_notakeover = client_notakeover
self._reset_heartbeat()
def _cancel_heartbeat(self) -> None:
if self._pong_response_cb is not None:
self._pong_response_cb.cancel()
self._pong_response_cb = None
if self._heartbeat_cb is not None:
self._heartbeat_cb.cancel()
self._heartbeat_cb = None
def _reset_heartbeat(self) -> None:
self._cancel_heartbeat()
if self._heartbeat is not None:
self._heartbeat_cb = call_later(
self._send_heartbeat, self._heartbeat, self._loop)
def _send_heartbeat(self) -> None:
if self._heartbeat is not None and not self._closed:
# Firing and forgetting a task is not perfect, but it may be OK for
# sending a ping. Otherwise we would need a long-lived heartbeat
# task in the class.
self._loop.create_task(self._writer.ping())
if self._pong_response_cb is not None:
self._pong_response_cb.cancel()
self._pong_response_cb = call_later(
self._pong_not_received, self._pong_heartbeat, self._loop)
def _pong_not_received(self) -> None:
if not self._closed:
self._closed = True
self._close_code = 1006
self._exception = asyncio.TimeoutError()
self._response.close()
self._waiting = False
self._exception = None
@property
def closed(self) -> bool:
def closed(self):
return self._closed
@property
def close_code(self) -> Optional[int]:
def close_code(self):
return self._close_code
@property
def protocol(self) -> Optional[str]:
def protocol(self):
return self._protocol
@property
def compress(self) -> int:
return self._compress
@property
def client_notakeover(self) -> bool:
return self._client_notakeover
def get_extra_info(self, name: str, default: Any=None) -> Any:
"""extra info from connection transport"""
conn = self._response.connection
if conn is None:
return default
transport = conn.transport
if transport is None:
return default
return transport.get_extra_info(name, default)
def exception(self) -> Optional[BaseException]:
def exception(self):
return self._exception
async def ping(self, message: bytes=b'') -> None:
await self._writer.ping(message)
def ping(self, message=b''):
if self._closed:
raise RuntimeError('websocket connection is closed')
self._writer.ping(message)
async def pong(self, message: bytes=b'') -> None:
await self._writer.pong(message)
def pong(self, message=b''):
if self._closed:
raise RuntimeError('websocket connection is closed')
self._writer.pong(message)
async def send_str(self, data: str,
compress: Optional[int]=None) -> None:
def send_str(self, data):
if self._closed:
raise RuntimeError('websocket connection is closed')
if not isinstance(data, str):
raise TypeError('data argument must be str (%r)' % type(data))
await self._writer.send(data, binary=False, compress=compress)
self._writer.send(data, binary=False)
async def send_bytes(self, data: bytes,
compress: Optional[int]=None) -> None:
def send_bytes(self, data):
if self._closed:
raise RuntimeError('websocket connection is closed')
if not isinstance(data, (bytes, bytearray, memoryview)):
raise TypeError('data argument must be byte-ish (%r)' %
type(data))
await self._writer.send(data, binary=True, compress=compress)
self._writer.send(data, binary=True)
async def send_json(self, data: Any,
compress: Optional[int]=None,
*, dumps: JSONEncoder=DEFAULT_JSON_ENCODER) -> None:
await self.send_str(dumps(data), compress=compress)
async def close(self, *, code: int=1000, message: bytes=b'') -> bool:
# we need to break `receive()` cycle first,
# `close()` may be called from different task
if self._waiting is not None and not self._closed:
self._reader.feed_data(WS_CLOSING_MESSAGE, 0)
await self._waiting
def send_json(self, data, *, dumps=json.dumps):
self.send_str(dumps(data))
@asyncio.coroutine
def close(self, *, code=1000, message=b''):
if not self._closed:
self._cancel_heartbeat()
self._closed = True
try:
await self._writer.close(code, message)
self._writer.close(code, message)
except asyncio.CancelledError:
self._close_code = 1006
self._response.close()
@ -187,8 +95,8 @@ class ClientWebSocketResponse:
while True:
try:
with async_timeout.timeout(self._timeout, loop=self._loop):
msg = await self._reader.read()
msg = yield from asyncio.wait_for(
self._reader.read(), self._timeout, loop=self._loop)
except asyncio.CancelledError:
self._close_code = 1006
self._response.close()
@ -206,96 +114,80 @@ class ClientWebSocketResponse:
else:
return False
async def receive(self, timeout: Optional[float]=None) -> WSMessage:
while True:
if self._waiting is not None:
raise RuntimeError(
'Concurrent call to receive() is not allowed')
@asyncio.coroutine
def receive(self):
if self._waiting:
raise RuntimeError('Concurrent call to receive() is not allowed')
if self._closed:
return WS_CLOSED_MESSAGE
elif self._closing:
await self.close()
return WS_CLOSED_MESSAGE
self._waiting = True
try:
while True:
if self._closed:
return CLOSED_MESSAGE
try:
self._waiting = self._loop.create_future()
try:
with async_timeout.timeout(
timeout or self._receive_timeout,
loop=self._loop):
msg = await self._reader.read()
self._reset_heartbeat()
finally:
waiter = self._waiting
self._waiting = None
set_result(waiter, True)
except (asyncio.CancelledError, asyncio.TimeoutError):
self._close_code = 1006
raise
except EofStream:
self._close_code = 1000
await self.close()
return WSMessage(WSMsgType.CLOSED, None, None)
except ClientError:
self._closed = True
self._close_code = 1006
return WS_CLOSED_MESSAGE
except WebSocketError as exc:
self._close_code = exc.code
await self.close(code=exc.code)
return WSMessage(WSMsgType.ERROR, exc, None)
except Exception as exc:
self._exception = exc
self._closing = True
self._close_code = 1006
await self.close()
return WSMessage(WSMsgType.ERROR, exc, None)
if msg.type == WSMsgType.CLOSE:
self._closing = True
self._close_code = msg.data
if not self._closed and self._autoclose:
await self.close()
elif msg.type == WSMsgType.CLOSING:
self._closing = True
elif msg.type == WSMsgType.PING and self._autoping:
await self.pong(msg.data)
continue
elif msg.type == WSMsgType.PONG and self._autoping:
continue
return msg
msg = yield from self._reader.read()
except (asyncio.CancelledError, asyncio.TimeoutError):
raise
except WebSocketError as exc:
self._close_code = exc.code
yield from self.close(code=exc.code)
return WSMessage(WSMsgType.ERROR, exc, None)
except Exception as exc:
self._exception = exc
self._closing = True
self._close_code = 1006
yield from self.close()
return WSMessage(WSMsgType.ERROR, exc, None)
async def receive_str(self, *, timeout: Optional[float]=None) -> str:
msg = await self.receive(timeout)
if msg.type == WSMsgType.CLOSE:
self._closing = True
self._close_code = msg.data
if not self._closed and self._autoclose:
yield from self.close()
return msg
if msg.type == WSMsgType.PING and self._autoping:
self.pong(msg.data)
elif msg.type == WSMsgType.PONG and self._autoping:
continue
else:
return msg
finally:
self._waiting = False
@asyncio.coroutine
def receive_str(self):
msg = yield from self.receive()
if msg.type != WSMsgType.TEXT:
raise TypeError(
"Received message {}:{!r} is not str".format(msg.type,
msg.data))
return msg.data
async def receive_bytes(self, *, timeout: Optional[float]=None) -> bytes:
msg = await self.receive(timeout)
@asyncio.coroutine
def receive_bytes(self):
msg = yield from self.receive()
if msg.type != WSMsgType.BINARY:
raise TypeError(
"Received message {}:{!r} is not bytes".format(msg.type,
msg.data))
return msg.data
async def receive_json(self,
*, loads: JSONDecoder=DEFAULT_JSON_DECODER,
timeout: Optional[float]=None) -> Any:
data = await self.receive_str(timeout=timeout)
@asyncio.coroutine
def receive_json(self, *, loads=json.loads):
data = yield from self.receive_str()
return loads(data)
def __aiter__(self) -> 'ClientWebSocketResponse':
return self
if PY_35:
def __aiter__(self):
return self
if not PY_352: # pragma: no cover
__aiter__ = asyncio.coroutine(__aiter__)
async def __anext__(self) -> WSMessage:
msg = await self.receive()
if msg.type in (WSMsgType.CLOSE,
WSMsgType.CLOSING,
WSMsgType.CLOSED):
raise StopAsyncIteration # NOQA
return msg
@asyncio.coroutine
def __anext__(self):
msg = yield from self.receive()
if msg.type == WSMsgType.CLOSE:
raise StopAsyncIteration # NOQA
return msg
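# --- Added illustration (not part of aiohttp): with the async-iterator
# protocol above, a client drains a websocket until a CLOSE frame raises
# StopAsyncIteration. A sketch against the awaitable 3.6-style methods
# from the removed lines; `ClientSession.ws_connect` is the entry point
# that constructs this ClientWebSocketResponse.
async def demo_echo_client(session, url):  # hypothetical, for illustration
    async with session.ws_connect(url) as ws:
        await ws.send_str('hi')
        async for msg in ws:
            if msg.type == WSMsgType.TEXT:
                print(msg.data)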

File diff suppressed because it is too large

@ -1,90 +1,57 @@
import asyncio
import datetime
import os # noqa
import pathlib
import pickle
import re
from collections import defaultdict
from http.cookies import BaseCookie, Morsel, SimpleCookie # noqa
from typing import ( # noqa
DefaultDict,
Dict,
Iterable,
Iterator,
Mapping,
Optional,
Set,
Tuple,
Union,
cast,
)
from yarl import URL
from collections.abc import Mapping
from http.cookies import Morsel, SimpleCookie
from math import ceil
from urllib.parse import urlsplit
from .abc import AbstractCookieJar
from .helpers import is_ip_address, next_whole_second
from .typedefs import LooseCookies, PathLike
__all__ = ('CookieJar', 'DummyCookieJar')
CookieItem = Union[str, 'Morsel[str]']
from .helpers import is_ip_address
class CookieJar(AbstractCookieJar):
"""Implements cookie storage adhering to RFC 6265."""
DATE_TOKENS_RE = re.compile(
r"[\x09\x20-\x2F\x3B-\x40\x5B-\x60\x7B-\x7E]*"
r"(?P<token>[\x00-\x08\x0A-\x1F\d:a-zA-Z\x7F-\xFF]+)")
"[\x09\x20-\x2F\x3B-\x40\x5B-\x60\x7B-\x7E]*"
"(?P<token>[\x00-\x08\x0A-\x1F\d:a-zA-Z\x7F-\xFF]+)")
DATE_HMS_TIME_RE = re.compile(r"(\d{1,2}):(\d{1,2}):(\d{1,2})")
DATE_DAY_OF_MONTH_RE = re.compile(r"(\d{1,2})")
DATE_MONTH_RE = re.compile("(jan)|(feb)|(mar)|(apr)|(may)|(jun)|(jul)|"
"(aug)|(sep)|(oct)|(nov)|(dec)", re.I)
DATE_YEAR_RE = re.compile(r"(\d{2,4})")
MAX_TIME = datetime.datetime.max.replace(
tzinfo=datetime.timezone.utc)
MAX_TIME = 2051215261.0 # far enough in the future (2035-01-01)
def __init__(self, *, unsafe: bool=False,
loop: Optional[asyncio.AbstractEventLoop]=None) -> None:
def __init__(self, *, unsafe=False, loop=None):
super().__init__(loop=loop)
self._cookies = defaultdict(SimpleCookie) #type: DefaultDict[str, SimpleCookie] # noqa
self._host_only_cookies = set() # type: Set[Tuple[str, str]]
self._cookies = defaultdict(SimpleCookie)
self._host_only_cookies = set()
self._unsafe = unsafe
self._next_expiration = next_whole_second()
self._expirations = {} # type: Dict[Tuple[str, str], datetime.datetime] # noqa: E501
def save(self, file_path: PathLike) -> None:
file_path = pathlib.Path(file_path)
with file_path.open(mode='wb') as f:
pickle.dump(self._cookies, f, pickle.HIGHEST_PROTOCOL)
def load(self, file_path: PathLike) -> None:
file_path = pathlib.Path(file_path)
with file_path.open(mode='rb') as f:
self._cookies = pickle.load(f)
self._next_expiration = ceil(self._loop.time())
self._expirations = {}
def clear(self) -> None:
def clear(self):
self._cookies.clear()
self._host_only_cookies.clear()
self._next_expiration = next_whole_second()
self._next_expiration = ceil(self._loop.time())
self._expirations.clear()
def __iter__(self) -> 'Iterator[Morsel[str]]':
def __iter__(self):
self._do_expiration()
for val in self._cookies.values():
yield from val.values()
def __len__(self) -> int:
def __len__(self):
return sum(1 for i in self)
def _do_expiration(self) -> None:
now = datetime.datetime.now(datetime.timezone.utc)
def _do_expiration(self):
now = self._loop.time()
if self._next_expiration > now:
return
if not self._expirations:
@ -94,7 +61,7 @@ class CookieJar(AbstractCookieJar):
cookies = self._cookies
expirations = self._expirations
for (domain, name), when in expirations.items():
if when <= now:
if when < now:
cookies[domain].pop(name, None)
to_del.append((domain, name))
self._host_only_cookies.discard((domain, name))
@ -103,34 +70,28 @@ class CookieJar(AbstractCookieJar):
for key in to_del:
del expirations[key]
try:
self._next_expiration = (next_expiration.replace(microsecond=0) +
datetime.timedelta(seconds=1))
except OverflowError:
self._next_expiration = self.MAX_TIME
self._next_expiration = ceil(next_expiration)
def _expire_cookie(self, when: datetime.datetime, domain: str, name: str
) -> None:
def _expire_cookie(self, when, domain, name):
self._next_expiration = min(self._next_expiration, when)
self._expirations[(domain, name)] = when
def update_cookies(self,
cookies: LooseCookies,
response_url: URL=URL()) -> None:
def update_cookies(self, cookies, response_url=None):
"""Update cookies."""
hostname = response_url.raw_host
url_parsed = urlsplit(response_url or "")
hostname = url_parsed.hostname
if not self._unsafe and is_ip_address(hostname):
# Don't accept cookies from IPs
return
if isinstance(cookies, Mapping):
cookies = cookies.items() # type: ignore
cookies = cookies.items()
for name, cookie in cookies:
if not isinstance(cookie, Morsel):
tmp = SimpleCookie()
tmp[name] = cookie # type: ignore
tmp[name] = cookie
cookie = tmp[name]
domain = cookie["domain"]
@ -158,7 +119,7 @@ class CookieJar(AbstractCookieJar):
path = cookie["path"]
if not path or not path.startswith("/"):
# Set the cookie's path to the response path
path = response_url.path
path = url_parsed.path
if not path.startswith("/"):
path = "/"
else:
@ -170,13 +131,7 @@ class CookieJar(AbstractCookieJar):
if max_age:
try:
delta_seconds = int(max_age)
try:
max_age_expiration = (
datetime.datetime.now(datetime.timezone.utc) +
datetime.timedelta(seconds=delta_seconds))
except OverflowError:
max_age_expiration = self.MAX_TIME
self._expire_cookie(max_age_expiration,
self._expire_cookie(self._loop.time() + delta_seconds,
domain, name)
except ValueError:
cookie["max-age"] = ""
@ -186,22 +141,24 @@ class CookieJar(AbstractCookieJar):
if expires:
expire_time = self._parse_date(expires)
if expire_time:
self._expire_cookie(expire_time,
self._expire_cookie(expire_time.timestamp(),
domain, name)
else:
cookie["expires"] = ""
self._cookies[domain][name] = cookie
# use the dict method because the SimpleCookie class modifies
# the value before Python 3.4.3
dict.__setitem__(self._cookies[domain], name, cookie)
self._do_expiration()
def filter_cookies(self, request_url: URL=URL()) -> 'BaseCookie[str]':
def filter_cookies(self, request_url):
"""Returns this jar's cookies filtered by their attributes."""
self._do_expiration()
request_url = URL(request_url)
url_parsed = urlsplit(request_url)
filtered = SimpleCookie()
hostname = request_url.raw_host or ""
is_not_secure = request_url.scheme not in ("https", "wss")
hostname = url_parsed.hostname or ""
is_not_secure = url_parsed.scheme not in ("https", "wss")
for cookie in self:
name = cookie.key
@ -221,22 +178,18 @@ class CookieJar(AbstractCookieJar):
elif not self._is_domain_match(domain, hostname):
continue
if not self._is_path_match(request_url.path, cookie["path"]):
if not self._is_path_match(url_parsed.path, cookie["path"]):
continue
if is_not_secure and cookie["secure"]:
continue
# It's critical we use the Morsel so the coded_value
# (based on cookie version) is preserved
mrsl_val = cast('Morsel[str]', cookie.get(cookie.key, Morsel()))
mrsl_val.set(cookie.key, cookie.value, cookie.coded_value)
filtered[name] = mrsl_val
filtered[name] = cookie.value
return filtered
@staticmethod
def _is_domain_match(domain: str, hostname: str) -> bool:
def _is_domain_match(domain, hostname):
"""Implements domain matching adhering to RFC 6265."""
if hostname == domain:
return True
@ -252,7 +205,7 @@ class CookieJar(AbstractCookieJar):
return not is_ip_address(hostname)
@staticmethod
def _is_path_match(req_path: str, cookie_path: str) -> bool:
def _is_path_match(req_path, cookie_path):
"""Implements path matching adhering to RFC 6265."""
if not req_path.startswith("/"):
req_path = "/"
@ -271,10 +224,10 @@ class CookieJar(AbstractCookieJar):
return non_matching.startswith("/")
@classmethod
def _parse_date(cls, date_str: str) -> Optional[datetime.datetime]:
def _parse_date(cls, date_str):
"""Implements date string parsing adhering to RFC 6265."""
if not date_str:
return None
return
found_time = False
found_day = False
@ -309,7 +262,6 @@ class CookieJar(AbstractCookieJar):
month_match = cls.DATE_MONTH_RE.match(token)
if month_match:
found_month = True
assert month_match.lastindex is not None
month = month_match.lastindex
continue
@ -325,44 +277,14 @@ class CookieJar(AbstractCookieJar):
year += 2000
if False in (found_day, found_month, found_year, found_time):
return None
return
if not 1 <= day <= 31:
return None
return
if year < 1601 or hour > 23 or minute > 59 or second > 59:
return None
return
return datetime.datetime(year, month, day,
hour, minute, second,
tzinfo=datetime.timezone.utc)
class DummyCookieJar(AbstractCookieJar):
"""Implements a dummy cookie storage.
It can be used with the ClientSession when no cookie processing is needed.
"""
def __init__(self, *,
loop: Optional[asyncio.AbstractEventLoop]=None) -> None:
super().__init__(loop=loop)
def __iter__(self) -> 'Iterator[Morsel[str]]':
while False:
yield None
def __len__(self) -> int:
return 0
def clear(self) -> None:
pass
def update_cookies(self,
cookies: LooseCookies,
response_url: URL=URL()) -> None:
pass
def filter_cookies(self, request_url: URL) -> 'BaseCookie[str]':
return SimpleCookie()
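A short sketch of the jar in use (the save()/load() pickle helpers and the URL-typed signatures exist only in the newer variant above; example.com is a placeholder):

from aiohttp import CookieJar
from yarl import URL

jar = CookieJar(unsafe=True)   # unsafe=True also accepts cookies from IPs
jar.update_cookies({'session': 'abc123'}, URL('http://example.com/'))
# filter_cookies() applies the RFC 6265 domain/path/secure checks above
morsels = jar.filter_cookies(URL('http://example.com/page'))
print(morsels['session'].value)   # abc123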

@ -0,0 +1,186 @@
"""HTTP related errors."""
from asyncio import TimeoutError
__all__ = (
'DisconnectedError', 'ClientDisconnectedError', 'ServerDisconnectedError',
'HttpProcessingError', 'BadHttpMessage',
'HttpMethodNotAllowed', 'HttpBadRequest', 'HttpProxyError',
'BadStatusLine', 'LineTooLong', 'InvalidHeader',
'ClientError', 'ClientHttpProcessingError', 'ClientConnectionError',
'ClientOSError', 'ClientTimeoutError', 'ProxyConnectionError',
'ClientRequestError', 'ClientResponseError',
'FingerprintMismatch',
'WSServerHandshakeError', 'WSClientDisconnectedError')
class DisconnectedError(Exception):
"""Disconnected."""
class ClientDisconnectedError(DisconnectedError):
"""Client disconnected."""
class ServerDisconnectedError(DisconnectedError):
"""Server disconnected."""
class WSClientDisconnectedError(ClientDisconnectedError):
"""Deprecated."""
class ClientError(Exception):
"""Base class for client connection errors."""
class ClientHttpProcessingError(ClientError):
"""Base class for client HTTP processing errors."""
class ClientRequestError(ClientHttpProcessingError):
"""Connection error during sending request."""
class ClientResponseError(ClientHttpProcessingError):
"""Connection error during reading response."""
class ClientConnectionError(ClientError):
"""Base class for client socket errors."""
class ClientOSError(ClientConnectionError, OSError):
"""OSError error."""
class ClientTimeoutError(ClientConnectionError, TimeoutError):
"""Client connection timeout error."""
class ProxyConnectionError(ClientConnectionError):
"""Proxy connection error.
Raised in :class:`aiohttp.connector.ProxyConnector` if
connection to proxy can not be established.
"""
class HttpProcessingError(Exception):
"""HTTP error.
Shortcut for raising HTTP errors with custom code, message and headers.
:param int code: HTTP Error code.
:param str message: (optional) Error message.
:param list of [tuple] headers: (optional) Headers to be sent in response.
"""
code = 0
message = ''
headers = None
def __init__(self, *, code=None, message='', headers=None):
if code is not None:
self.code = code
self.headers = headers
self.message = message
super().__init__("%s, message='%s'" % (self.code, message))
class WSServerHandshakeError(HttpProcessingError):
"""websocket server handshake error."""
class HttpProxyError(HttpProcessingError):
"""HTTP proxy error.
Raised in :class:`aiohttp.connector.ProxyConnector` if
proxy responds with status other than ``200 OK``
on ``CONNECT`` request.
"""
class BadHttpMessage(HttpProcessingError):
code = 400
message = 'Bad Request'
def __init__(self, message, *, headers=None):
super().__init__(message=message, headers=headers)
class HttpMethodNotAllowed(HttpProcessingError):
code = 405
message = 'Method Not Allowed'
class HttpBadRequest(BadHttpMessage):
code = 400
message = 'Bad Request'
class ContentEncodingError(BadHttpMessage):
"""Content encoding error."""
class TransferEncodingError(BadHttpMessage):
"""transfer encoding error."""
class LineTooLong(BadHttpMessage):
def __init__(self, line, limit='Unknown'):
super().__init__(
"got more than %s bytes when reading %s" % (limit, line))
class InvalidHeader(BadHttpMessage):
def __init__(self, hdr):
if isinstance(hdr, bytes):
hdr = hdr.decode('utf-8', 'surrogateescape')
super().__init__('Invalid HTTP Header: {}'.format(hdr))
self.hdr = hdr
class BadStatusLine(BadHttpMessage):
def __init__(self, line=''):
if not line:
line = repr(line)
self.args = line,
self.line = line
class LineLimitExceededParserError(HttpBadRequest):
"""Line is too long."""
def __init__(self, msg, limit):
super().__init__(msg)
self.limit = limit
class FingerprintMismatch(ClientConnectionError):
"""SSL certificate does not match expected fingerprint."""
def __init__(self, expected, got, host, port):
self.expected = expected
self.got = got
self.host = host
self.port = port
def __repr__(self):
return '<{} expected={} got={} host={} port={}>'.format(
self.__class__.__name__, self.expected, self.got,
self.host, self.port)
class InvalidURL(Exception):
"""Invalid URL."""

@ -1,150 +0,0 @@
import io
from typing import Any, Iterable, List, Optional # noqa
from urllib.parse import urlencode
from multidict import MultiDict, MultiDictProxy
from . import hdrs, multipart, payload
from .helpers import guess_filename
from .payload import Payload
__all__ = ('FormData',)
class FormData:
"""Helper class for multipart/form-data and
application/x-www-form-urlencoded body generation."""
def __init__(self, fields:
Iterable[Any]=(),
quote_fields: bool=True,
charset: Optional[str]=None) -> None:
self._writer = multipart.MultipartWriter('form-data')
self._fields = [] # type: List[Any]
self._is_multipart = False
self._quote_fields = quote_fields
self._charset = charset
if isinstance(fields, dict):
fields = list(fields.items())
elif not isinstance(fields, (list, tuple)):
fields = (fields,)
self.add_fields(*fields)
@property
def is_multipart(self) -> bool:
return self._is_multipart
def add_field(self, name: str, value: Any, *,
content_type: Optional[str]=None,
filename: Optional[str]=None,
content_transfer_encoding: Optional[str]=None) -> None:
if isinstance(value, io.IOBase):
self._is_multipart = True
elif isinstance(value, (bytes, bytearray, memoryview)):
if filename is None and content_transfer_encoding is None:
filename = name
type_options = MultiDict({'name': name})
if filename is not None and not isinstance(filename, str):
raise TypeError('filename must be an instance of str. '
'Got: %s' % filename)
if filename is None and isinstance(value, io.IOBase):
filename = guess_filename(value, name)
if filename is not None:
type_options['filename'] = filename
self._is_multipart = True
headers = {}
if content_type is not None:
if not isinstance(content_type, str):
raise TypeError('content_type must be an instance of str. '
'Got: %s' % content_type)
headers[hdrs.CONTENT_TYPE] = content_type
self._is_multipart = True
if content_transfer_encoding is not None:
if not isinstance(content_transfer_encoding, str):
raise TypeError('content_transfer_encoding must be an instance'
' of str. Got: %s' % content_transfer_encoding)
headers[hdrs.CONTENT_TRANSFER_ENCODING] = content_transfer_encoding
self._is_multipart = True
self._fields.append((type_options, headers, value))
def add_fields(self, *fields: Any) -> None:
to_add = list(fields)
while to_add:
rec = to_add.pop(0)
if isinstance(rec, io.IOBase):
k = guess_filename(rec, 'unknown')
self.add_field(k, rec) # type: ignore
elif isinstance(rec, (MultiDictProxy, MultiDict)):
to_add.extend(rec.items())
elif isinstance(rec, (list, tuple)) and len(rec) == 2:
k, fp = rec
self.add_field(k, fp) # type: ignore
else:
raise TypeError('Only io.IOBase, multidict and (name, file) '
'pairs allowed, use .add_field() for passing '
'more complex parameters, got {!r}'
.format(rec))
def _gen_form_urlencoded(self) -> payload.BytesPayload:
# form data (x-www-form-urlencoded)
data = []
for type_options, _, value in self._fields:
data.append((type_options['name'], value))
charset = self._charset if self._charset is not None else 'utf-8'
if charset == 'utf-8':
content_type = 'application/x-www-form-urlencoded'
else:
content_type = ('application/x-www-form-urlencoded; '
'charset=%s' % charset)
return payload.BytesPayload(
urlencode(data, doseq=True, encoding=charset).encode(),
content_type=content_type)
def _gen_form_data(self) -> multipart.MultipartWriter:
"""Encode a list of fields using the multipart/form-data MIME format"""
for dispparams, headers, value in self._fields:
try:
if hdrs.CONTENT_TYPE in headers:
part = payload.get_payload(
value, content_type=headers[hdrs.CONTENT_TYPE],
headers=headers, encoding=self._charset)
else:
part = payload.get_payload(
value, headers=headers, encoding=self._charset)
except Exception as exc:
raise TypeError(
'Can not serialize value type: %r\n '
'headers: %r\n value: %r' % (
type(value), headers, value)) from exc
if dispparams:
part.set_content_disposition(
'form-data', quote_fields=self._quote_fields, **dispparams
)
# FIXME cgi.FieldStorage doesn't like body parts with
# Content-Length which were sent via chunked transfer encoding
assert part.headers is not None
part.headers.popall(hdrs.CONTENT_LENGTH, None)
self._writer.append_payload(part)
return self._writer
def __call__(self) -> Payload:
if self._is_multipart:
return self._gen_form_data()
else:
return self._gen_form_urlencoded()
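To illustrate the dispatch in __call__ above: plain string fields keep the form urlencoded, while a file-like value (or an explicit content_type) flips _is_multipart. A sketch against the payload-based variant shown here:

import io
from aiohttp import FormData

form = FormData()
form.add_field('user', 'sgoudham')              # stays x-www-form-urlencoded
form.add_field('avatar', io.BytesIO(b'\x89PNG'),
               filename='avatar.png',
               content_type='image/png')        # flips _is_multipart
print(form.is_multipart)   # True
payload = form()           # dispatches to _gen_form_data()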

@ -1,72 +0,0 @@
from collections.abc import MutableSequence
from functools import total_ordering
from .helpers import NO_EXTENSIONS
@total_ordering
class FrozenList(MutableSequence):
__slots__ = ('_frozen', '_items')
def __init__(self, items=None):
self._frozen = False
if items is not None:
items = list(items)
else:
items = []
self._items = items
@property
def frozen(self):
return self._frozen
def freeze(self):
self._frozen = True
def __getitem__(self, index):
return self._items[index]
def __setitem__(self, index, value):
if self._frozen:
raise RuntimeError("Cannot modify frozen list.")
self._items[index] = value
def __delitem__(self, index):
if self._frozen:
raise RuntimeError("Cannot modify frozen list.")
del self._items[index]
def __len__(self):
return self._items.__len__()
def __iter__(self):
return self._items.__iter__()
def __reversed__(self):
return self._items.__reversed__()
def __eq__(self, other):
return list(self) == other
def __le__(self, other):
return list(self) <= other
def insert(self, pos, item):
if self._frozen:
raise RuntimeError("Cannot modify frozen list.")
self._items.insert(pos, item)
def __repr__(self):
return '<FrozenList(frozen={}, {!r})>'.format(self._frozen,
self._items)
PyFrozenList = FrozenList
try:
from aiohttp._frozenlist import FrozenList as CFrozenList # type: ignore
if not NO_EXTENSIONS:
FrozenList = CFrozenList # type: ignore
except ImportError: # pragma: no cover
pass
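A quick sketch of the freeze contract (import path as vendored here):

from aiohttp.frozenlist import FrozenList

fl = FrozenList([1, 2, 3])
fl.append(4)      # MutableSequence derives append() from insert()
fl.freeze()
try:
    fl[0] = 99    # any mutation after freeze() raises
except RuntimeError as exc:
    print(exc)    # Cannot modify frozen list.
print(fl == [1, 2, 3, 4])   # comparisons delegate to list(self)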

@ -1,63 +0,0 @@
from typing import (
Generic,
Iterable,
Iterator,
List,
MutableSequence,
Optional,
TypeVar,
Union,
overload,
)
_T = TypeVar('_T')
_Arg = Union[List[_T], Iterable[_T]]
class FrozenList(MutableSequence[_T], Generic[_T]):
def __init__(self, items: Optional[_Arg[_T]]=...) -> None: ...
@property
def frozen(self) -> bool: ...
def freeze(self) -> None: ...
@overload
def __getitem__(self, i: int) -> _T: ...
@overload
def __getitem__(self, s: slice) -> FrozenList[_T]: ...
@overload
def __setitem__(self, i: int, o: _T) -> None: ...
@overload
def __setitem__(self, s: slice, o: Iterable[_T]) -> None: ...
@overload
def __delitem__(self, i: int) -> None: ...
@overload
def __delitem__(self, i: slice) -> None: ...
def __len__(self) -> int: ...
def __iter__(self) -> Iterator[_T]: ...
def __reversed__(self) -> Iterator[_T]: ...
def __eq__(self, other: object) -> bool: ...
def __le__(self, other: FrozenList[_T]) -> bool: ...
def __ne__(self, other: object) -> bool: ...
def __lt__(self, other: FrozenList[_T]) -> bool: ...
def __ge__(self, other: FrozenList[_T]) -> bool: ...
def __gt__(self, other: FrozenList[_T]) -> bool: ...
def insert(self, pos: int, item: _T) -> None: ...
def __repr__(self) -> str: ...
# types for C accelerators are the same
CFrozenList = PyFrozenList = FrozenList

@ -1,8 +1,4 @@
"""HTTP Headers constants."""
# After changing the file content call ./tools/gen.py
# to regenerate the headers parser
from multidict import istr
METH_ANY = '*'
@ -20,81 +16,76 @@ METH_ALL = {METH_CONNECT, METH_HEAD, METH_GET, METH_DELETE,
METH_OPTIONS, METH_PATCH, METH_POST, METH_PUT, METH_TRACE}
ACCEPT = istr('Accept')
ACCEPT_CHARSET = istr('Accept-Charset')
ACCEPT_ENCODING = istr('Accept-Encoding')
ACCEPT_LANGUAGE = istr('Accept-Language')
ACCEPT_RANGES = istr('Accept-Ranges')
ACCESS_CONTROL_MAX_AGE = istr('Access-Control-Max-Age')
ACCESS_CONTROL_ALLOW_CREDENTIALS = istr('Access-Control-Allow-Credentials')
ACCESS_CONTROL_ALLOW_HEADERS = istr('Access-Control-Allow-Headers')
ACCESS_CONTROL_ALLOW_METHODS = istr('Access-Control-Allow-Methods')
ACCESS_CONTROL_ALLOW_ORIGIN = istr('Access-Control-Allow-Origin')
ACCESS_CONTROL_EXPOSE_HEADERS = istr('Access-Control-Expose-Headers')
ACCESS_CONTROL_REQUEST_HEADERS = istr('Access-Control-Request-Headers')
ACCESS_CONTROL_REQUEST_METHOD = istr('Access-Control-Request-Method')
AGE = istr('Age')
ALLOW = istr('Allow')
AUTHORIZATION = istr('Authorization')
CACHE_CONTROL = istr('Cache-Control')
CONNECTION = istr('Connection')
CONTENT_DISPOSITION = istr('Content-Disposition')
CONTENT_ENCODING = istr('Content-Encoding')
CONTENT_LANGUAGE = istr('Content-Language')
CONTENT_LENGTH = istr('Content-Length')
CONTENT_LOCATION = istr('Content-Location')
CONTENT_MD5 = istr('Content-MD5')
CONTENT_RANGE = istr('Content-Range')
CONTENT_TRANSFER_ENCODING = istr('Content-Transfer-Encoding')
CONTENT_TYPE = istr('Content-Type')
COOKIE = istr('Cookie')
DATE = istr('Date')
DESTINATION = istr('Destination')
DIGEST = istr('Digest')
ETAG = istr('Etag')
EXPECT = istr('Expect')
EXPIRES = istr('Expires')
FORWARDED = istr('Forwarded')
FROM = istr('From')
HOST = istr('Host')
IF_MATCH = istr('If-Match')
IF_MODIFIED_SINCE = istr('If-Modified-Since')
IF_NONE_MATCH = istr('If-None-Match')
IF_RANGE = istr('If-Range')
IF_UNMODIFIED_SINCE = istr('If-Unmodified-Since')
KEEP_ALIVE = istr('Keep-Alive')
LAST_EVENT_ID = istr('Last-Event-ID')
LAST_MODIFIED = istr('Last-Modified')
LINK = istr('Link')
LOCATION = istr('Location')
MAX_FORWARDS = istr('Max-Forwards')
ORIGIN = istr('Origin')
PRAGMA = istr('Pragma')
PROXY_AUTHENTICATE = istr('Proxy-Authenticate')
PROXY_AUTHORIZATION = istr('Proxy-Authorization')
RANGE = istr('Range')
REFERER = istr('Referer')
RETRY_AFTER = istr('Retry-After')
SEC_WEBSOCKET_ACCEPT = istr('Sec-WebSocket-Accept')
SEC_WEBSOCKET_VERSION = istr('Sec-WebSocket-Version')
SEC_WEBSOCKET_PROTOCOL = istr('Sec-WebSocket-Protocol')
SEC_WEBSOCKET_EXTENSIONS = istr('Sec-WebSocket-Extensions')
SEC_WEBSOCKET_KEY = istr('Sec-WebSocket-Key')
SEC_WEBSOCKET_KEY1 = istr('Sec-WebSocket-Key1')
SERVER = istr('Server')
SET_COOKIE = istr('Set-Cookie')
ACCEPT = istr('ACCEPT')
ACCEPT_CHARSET = istr('ACCEPT-CHARSET')
ACCEPT_ENCODING = istr('ACCEPT-ENCODING')
ACCEPT_LANGUAGE = istr('ACCEPT-LANGUAGE')
ACCEPT_RANGES = istr('ACCEPT-RANGES')
ACCESS_CONTROL_MAX_AGE = istr('ACCESS-CONTROL-MAX-AGE')
ACCESS_CONTROL_ALLOW_CREDENTIALS = istr('ACCESS-CONTROL-ALLOW-CREDENTIALS')
ACCESS_CONTROL_ALLOW_HEADERS = istr('ACCESS-CONTROL-ALLOW-HEADERS')
ACCESS_CONTROL_ALLOW_METHODS = istr('ACCESS-CONTROL-ALLOW-METHODS')
ACCESS_CONTROL_ALLOW_ORIGIN = istr('ACCESS-CONTROL-ALLOW-ORIGIN')
ACCESS_CONTROL_EXPOSE_HEADERS = istr('ACCESS-CONTROL-EXPOSE-HEADERS')
ACCESS_CONTROL_REQUEST_HEADERS = istr('ACCESS-CONTROL-REQUEST-HEADERS')
ACCESS_CONTROL_REQUEST_METHOD = istr('ACCESS-CONTROL-REQUEST-METHOD')
AGE = istr('AGE')
ALLOW = istr('ALLOW')
AUTHORIZATION = istr('AUTHORIZATION')
CACHE_CONTROL = istr('CACHE-CONTROL')
CONNECTION = istr('CONNECTION')
CONTENT_DISPOSITION = istr('CONTENT-DISPOSITION')
CONTENT_ENCODING = istr('CONTENT-ENCODING')
CONTENT_LANGUAGE = istr('CONTENT-LANGUAGE')
CONTENT_LENGTH = istr('CONTENT-LENGTH')
CONTENT_LOCATION = istr('CONTENT-LOCATION')
CONTENT_MD5 = istr('CONTENT-MD5')
CONTENT_RANGE = istr('CONTENT-RANGE')
CONTENT_TRANSFER_ENCODING = istr('CONTENT-TRANSFER-ENCODING')
CONTENT_TYPE = istr('CONTENT-TYPE')
COOKIE = istr('COOKIE')
DATE = istr('DATE')
DESTINATION = istr('DESTINATION')
DIGEST = istr('DIGEST')
ETAG = istr('ETAG')
EXPECT = istr('EXPECT')
EXPIRES = istr('EXPIRES')
FROM = istr('FROM')
HOST = istr('HOST')
IF_MATCH = istr('IF-MATCH')
IF_MODIFIED_SINCE = istr('IF-MODIFIED-SINCE')
IF_NONE_MATCH = istr('IF-NONE-MATCH')
IF_RANGE = istr('IF-RANGE')
IF_UNMODIFIED_SINCE = istr('IF-UNMODIFIED-SINCE')
KEEP_ALIVE = istr('KEEP-ALIVE')
LAST_EVENT_ID = istr('LAST-EVENT-ID')
LAST_MODIFIED = istr('LAST-MODIFIED')
LINK = istr('LINK')
LOCATION = istr('LOCATION')
MAX_FORWARDS = istr('MAX-FORWARDS')
ORIGIN = istr('ORIGIN')
PRAGMA = istr('PRAGMA')
PROXY_AUTHENTICATE = istr('PROXY-AUTHENTICATE')
PROXY_AUTHORIZATION = istr('PROXY-AUTHORIZATION')
RANGE = istr('RANGE')
REFERER = istr('REFERER')
RETRY_AFTER = istr('RETRY-AFTER')
SEC_WEBSOCKET_ACCEPT = istr('SEC-WEBSOCKET-ACCEPT')
SEC_WEBSOCKET_VERSION = istr('SEC-WEBSOCKET-VERSION')
SEC_WEBSOCKET_PROTOCOL = istr('SEC-WEBSOCKET-PROTOCOL')
SEC_WEBSOCKET_KEY = istr('SEC-WEBSOCKET-KEY')
SEC_WEBSOCKET_KEY1 = istr('SEC-WEBSOCKET-KEY1')
SERVER = istr('SERVER')
SET_COOKIE = istr('SET-COOKIE')
TE = istr('TE')
TRAILER = istr('Trailer')
TRANSFER_ENCODING = istr('Transfer-Encoding')
UPGRADE = istr('Upgrade')
WEBSOCKET = istr('WebSocket')
TRAILER = istr('TRAILER')
TRANSFER_ENCODING = istr('TRANSFER-ENCODING')
UPGRADE = istr('UPGRADE')
WEBSOCKET = istr('WEBSOCKET')
URI = istr('URI')
USER_AGENT = istr('User-Agent')
VARY = istr('Vary')
VIA = istr('Via')
WANT_DIGEST = istr('Want-Digest')
WARNING = istr('Warning')
WWW_AUTHENTICATE = istr('WWW-Authenticate')
X_FORWARDED_FOR = istr('X-Forwarded-For')
X_FORWARDED_HOST = istr('X-Forwarded-Host')
X_FORWARDED_PROTO = istr('X-Forwarded-Proto')
USER_AGENT = istr('USER-AGENT')
VARY = istr('VARY')
VIA = istr('VIA')
WANT_DIGEST = istr('WANT-DIGEST')
WARNING = istr('WARNING')
WWW_AUTHENTICATE = istr('WWW-AUTHENTICATE')
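These constants are istr instances so that header lookups stay case-insensitive when used with a CIMultiDict; for example:

from multidict import CIMultiDict, istr

CONTENT_TYPE = istr('Content-Type')
headers = CIMultiDict()
headers[CONTENT_TYPE] = 'application/json'
# CIMultiDict keys match case-insensitively, so any casing resolves
print(headers['content-type'])         # application/json
print(headers[istr('CONTENT-TYPE')])   # same entry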

File diff suppressed because it is too large.

@ -1,50 +0,0 @@
import http.server
import sys
from typing import Mapping, Tuple # noqa
from . import __version__
from .http_exceptions import HttpProcessingError as HttpProcessingError
from .http_parser import HeadersParser as HeadersParser
from .http_parser import HttpParser as HttpParser
from .http_parser import HttpRequestParser as HttpRequestParser
from .http_parser import HttpResponseParser as HttpResponseParser
from .http_parser import RawRequestMessage as RawRequestMessage
from .http_parser import RawResponseMessage as RawResponseMessage
from .http_websocket import WS_CLOSED_MESSAGE as WS_CLOSED_MESSAGE
from .http_websocket import WS_CLOSING_MESSAGE as WS_CLOSING_MESSAGE
from .http_websocket import WS_KEY as WS_KEY
from .http_websocket import WebSocketError as WebSocketError
from .http_websocket import WebSocketReader as WebSocketReader
from .http_websocket import WebSocketWriter as WebSocketWriter
from .http_websocket import WSCloseCode as WSCloseCode
from .http_websocket import WSMessage as WSMessage
from .http_websocket import WSMsgType as WSMsgType
from .http_websocket import ws_ext_gen as ws_ext_gen
from .http_websocket import ws_ext_parse as ws_ext_parse
from .http_writer import HttpVersion as HttpVersion
from .http_writer import HttpVersion10 as HttpVersion10
from .http_writer import HttpVersion11 as HttpVersion11
from .http_writer import StreamWriter as StreamWriter
__all__ = (
'HttpProcessingError', 'RESPONSES', 'SERVER_SOFTWARE',
# .http_writer
'StreamWriter', 'HttpVersion', 'HttpVersion10', 'HttpVersion11',
# .http_parser
'HeadersParser', 'HttpParser',
'HttpRequestParser', 'HttpResponseParser',
'RawRequestMessage', 'RawResponseMessage',
# .http_websocket
'WS_CLOSED_MESSAGE', 'WS_CLOSING_MESSAGE', 'WS_KEY',
'WebSocketReader', 'WebSocketWriter', 'ws_ext_gen', 'ws_ext_parse',
'WSMessage', 'WebSocketError', 'WSMsgType', 'WSCloseCode',
)
SERVER_SOFTWARE = 'Python/{0[0]}.{0[1]} aiohttp/{1}'.format(
sys.version_info, __version__) # type: str
RESPONSES = http.server.BaseHTTPRequestHandler.responses # type: Mapping[int, Tuple[str, str]] # noqa
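RESPONSES is simply the stdlib status table keyed by code; for example:

import http.server

RESPONSES = http.server.BaseHTTPRequestHandler.responses
print(RESPONSES[404])   # ('Not Found', 'Nothing matches the given URI')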

@ -1,108 +0,0 @@
"""Low-level http related exceptions."""
from typing import Optional, Union
from .typedefs import _CIMultiDict
__all__ = ('HttpProcessingError',)
class HttpProcessingError(Exception):
"""HTTP error.
Shortcut for raising HTTP errors with custom code, message and headers.
code: HTTP Error code.
message: (optional) Error message.
headers: (optional) Headers to be sent in response, a list of pairs
"""
code = 0
message = ''
headers = None
def __init__(self, *,
code: Optional[int]=None,
message: str='',
headers: Optional[_CIMultiDict]=None) -> None:
if code is not None:
self.code = code
self.headers = headers
self.message = message
def __str__(self) -> str:
return "%s, message=%r" % (self.code, self.message)
def __repr__(self) -> str:
return "<%s: %s>" % (self.__class__.__name__, self)
class BadHttpMessage(HttpProcessingError):
code = 400
message = 'Bad Request'
def __init__(self, message: str, *,
headers: Optional[_CIMultiDict]=None) -> None:
super().__init__(message=message, headers=headers)
self.args = (message,)
class HttpBadRequest(BadHttpMessage):
code = 400
message = 'Bad Request'
class PayloadEncodingError(BadHttpMessage):
"""Base class for payload errors"""
class ContentEncodingError(PayloadEncodingError):
"""Content encoding error."""
class TransferEncodingError(PayloadEncodingError):
"""transfer encoding error."""
class ContentLengthError(PayloadEncodingError):
"""Not enough data for satisfy content length header."""
class LineTooLong(BadHttpMessage):
def __init__(self, line: str,
limit: str='Unknown',
actual_size: str='Unknown') -> None:
super().__init__(
"Got more than %s bytes (%s) when reading %s." % (
limit, actual_size, line))
self.args = (line, limit, actual_size)
class InvalidHeader(BadHttpMessage):
def __init__(self, hdr: Union[bytes, str]) -> None:
if isinstance(hdr, bytes):
hdr = hdr.decode('utf-8', 'surrogateescape')
super().__init__('Invalid HTTP Header: {}'.format(hdr))
self.hdr = hdr
self.args = (hdr,)
class BadStatusLine(BadHttpMessage):
def __init__(self, line: str='') -> None:
if not isinstance(line, str):
line = repr(line)
self.args = (line,)
self.line = line
__str__ = Exception.__str__
__repr__ = Exception.__repr__
class InvalidURLError(BadHttpMessage):
pass

@ -1,764 +0,0 @@
import abc
import asyncio
import collections
import re
import string
import zlib
from enum import IntEnum
from typing import Any, List, Optional, Tuple, Type, Union # noqa
from multidict import CIMultiDict, CIMultiDictProxy, istr
from yarl import URL
from . import hdrs
from .base_protocol import BaseProtocol
from .helpers import NO_EXTENSIONS, BaseTimerContext
from .http_exceptions import (
BadStatusLine,
ContentEncodingError,
ContentLengthError,
InvalidHeader,
LineTooLong,
TransferEncodingError,
)
from .http_writer import HttpVersion, HttpVersion10
from .log import internal_logger
from .streams import EMPTY_PAYLOAD, StreamReader
from .typedefs import RawHeaders
try:
import brotli
HAS_BROTLI = True
except ImportError: # pragma: no cover
HAS_BROTLI = False
__all__ = (
'HeadersParser', 'HttpParser', 'HttpRequestParser', 'HttpResponseParser',
'RawRequestMessage', 'RawResponseMessage')
ASCIISET = set(string.printable)
# See https://tools.ietf.org/html/rfc7230#section-3.1.1
# and https://tools.ietf.org/html/rfc7230#appendix-B
#
# method = token
# tchar = "!" / "#" / "$" / "%" / "&" / "'" / "*" / "+" / "-" / "." /
# "^" / "_" / "`" / "|" / "~" / DIGIT / ALPHA
# token = 1*tchar
METHRE = re.compile(r"[!#$%&'*+\-.^_`|~0-9A-Za-z]+")
VERSRE = re.compile(r'HTTP/(\d+).(\d+)')
HDRRE = re.compile(rb'[\x00-\x1F\x7F()<>@,;:\[\]={} \t\\\\\"]')
RawRequestMessage = collections.namedtuple(
'RawRequestMessage',
['method', 'path', 'version', 'headers', 'raw_headers',
'should_close', 'compression', 'upgrade', 'chunked', 'url'])
RawResponseMessage = collections.namedtuple(
'RawResponseMessage',
['version', 'code', 'reason', 'headers', 'raw_headers',
'should_close', 'compression', 'upgrade', 'chunked'])
class ParseState(IntEnum):
PARSE_NONE = 0
PARSE_LENGTH = 1
PARSE_CHUNKED = 2
PARSE_UNTIL_EOF = 3
class ChunkState(IntEnum):
PARSE_CHUNKED_SIZE = 0
PARSE_CHUNKED_CHUNK = 1
PARSE_CHUNKED_CHUNK_EOF = 2
PARSE_MAYBE_TRAILERS = 3
PARSE_TRAILERS = 4
class HeadersParser:
def __init__(self,
max_line_size: int=8190,
max_headers: int=32768,
max_field_size: int=8190) -> None:
self.max_line_size = max_line_size
self.max_headers = max_headers
self.max_field_size = max_field_size
def parse_headers(
self,
lines: List[bytes]
) -> Tuple['CIMultiDictProxy[str]', RawHeaders]:
headers = CIMultiDict() # type: CIMultiDict[str]
raw_headers = []
lines_idx = 1
line = lines[1]
line_count = len(lines)
while line:
# Parse initial header name : value pair.
try:
bname, bvalue = line.split(b':', 1)
except ValueError:
raise InvalidHeader(line) from None
bname = bname.strip(b' \t')
bvalue = bvalue.lstrip()
if HDRRE.search(bname):
raise InvalidHeader(bname)
if len(bname) > self.max_field_size:
raise LineTooLong(
"request header name {}".format(
bname.decode("utf8", "xmlcharrefreplace")),
str(self.max_field_size),
str(len(bname)))
header_length = len(bvalue)
# next line
lines_idx += 1
line = lines[lines_idx]
# consume continuation lines
continuation = line and line[0] in (32, 9) # (' ', '\t')
if continuation:
bvalue_lst = [bvalue]
while continuation:
header_length += len(line)
if header_length > self.max_field_size:
raise LineTooLong(
'request header field {}'.format(
bname.decode("utf8", "xmlcharrefreplace")),
str(self.max_field_size),
str(header_length))
bvalue_lst.append(line)
# next line
lines_idx += 1
if lines_idx < line_count:
line = lines[lines_idx]
if line:
continuation = line[0] in (32, 9) # (' ', '\t')
else:
line = b''
break
bvalue = b''.join(bvalue_lst)
else:
if header_length > self.max_field_size:
raise LineTooLong(
'request header field {}'.format(
bname.decode("utf8", "xmlcharrefreplace")),
str(self.max_field_size),
str(header_length))
bvalue = bvalue.strip()
name = bname.decode('utf-8', 'surrogateescape')
value = bvalue.decode('utf-8', 'surrogateescape')
headers.add(name, value)
raw_headers.append((bname, bvalue))
return (CIMultiDictProxy(headers), tuple(raw_headers))
class HttpParser(abc.ABC):
def __init__(self, protocol: Optional[BaseProtocol]=None,
loop: Optional[asyncio.AbstractEventLoop]=None,
max_line_size: int=8190,
max_headers: int=32768,
max_field_size: int=8190,
timer: Optional[BaseTimerContext]=None,
code: Optional[int]=None,
method: Optional[str]=None,
readall: bool=False,
payload_exception: Optional[Type[BaseException]]=None,
response_with_body: bool=True,
read_until_eof: bool=False,
auto_decompress: bool=True) -> None:
self.protocol = protocol
self.loop = loop
self.max_line_size = max_line_size
self.max_headers = max_headers
self.max_field_size = max_field_size
self.timer = timer
self.code = code
self.method = method
self.readall = readall
self.payload_exception = payload_exception
self.response_with_body = response_with_body
self.read_until_eof = read_until_eof
self._lines = [] # type: List[bytes]
self._tail = b''
self._upgraded = False
self._payload = None
self._payload_parser = None # type: Optional[HttpPayloadParser]
self._auto_decompress = auto_decompress
self._headers_parser = HeadersParser(max_line_size,
max_headers,
max_field_size)
@abc.abstractmethod
def parse_message(self, lines: List[bytes]) -> Any:
pass
def feed_eof(self) -> Any:
if self._payload_parser is not None:
self._payload_parser.feed_eof()
self._payload_parser = None
else:
# try to extract partial message
if self._tail:
self._lines.append(self._tail)
if self._lines:
if self._lines[-1] != b'\r\n':
self._lines.append(b'')
try:
return self.parse_message(self._lines)
except Exception:
return None
def feed_data(
self,
data: bytes,
SEP: bytes=b'\r\n',
EMPTY: bytes=b'',
CONTENT_LENGTH: istr=hdrs.CONTENT_LENGTH,
METH_CONNECT: str=hdrs.METH_CONNECT,
SEC_WEBSOCKET_KEY1: istr=hdrs.SEC_WEBSOCKET_KEY1
) -> Tuple[List[Any], bool, bytes]:
messages = []
if self._tail:
data, self._tail = self._tail + data, b''
data_len = len(data)
start_pos = 0
loop = self.loop
while start_pos < data_len:
# read HTTP message (request/response line + headers), \r\n\r\n
# and split by lines
if self._payload_parser is None and not self._upgraded:
pos = data.find(SEP, start_pos)
# consume \r\n
if pos == start_pos and not self._lines:
start_pos = pos + 2
continue
if pos >= start_pos:
# line found
self._lines.append(data[start_pos:pos])
start_pos = pos + 2
# \r\n\r\n found
if self._lines[-1] == EMPTY:
try:
msg = self.parse_message(self._lines)
finally:
self._lines.clear()
# payload length
length = msg.headers.get(CONTENT_LENGTH)
if length is not None:
try:
length = int(length)
except ValueError:
raise InvalidHeader(CONTENT_LENGTH)
if length < 0:
raise InvalidHeader(CONTENT_LENGTH)
# do not support old websocket spec
if SEC_WEBSOCKET_KEY1 in msg.headers:
raise InvalidHeader(SEC_WEBSOCKET_KEY1)
self._upgraded = msg.upgrade
method = getattr(msg, 'method', self.method)
assert self.protocol is not None
# calculate payload
if ((length is not None and length > 0) or
msg.chunked and not msg.upgrade):
payload = StreamReader(
self.protocol, timer=self.timer, loop=loop)
payload_parser = HttpPayloadParser(
payload, length=length,
chunked=msg.chunked, method=method,
compression=msg.compression,
code=self.code, readall=self.readall,
response_with_body=self.response_with_body,
auto_decompress=self._auto_decompress)
if not payload_parser.done:
self._payload_parser = payload_parser
elif method == METH_CONNECT:
payload = StreamReader(
self.protocol, timer=self.timer, loop=loop)
self._upgraded = True
self._payload_parser = HttpPayloadParser(
payload, method=msg.method,
compression=msg.compression, readall=True,
auto_decompress=self._auto_decompress)
else:
if (getattr(msg, 'code', 100) >= 199 and
length is None and self.read_until_eof):
payload = StreamReader(
self.protocol, timer=self.timer, loop=loop)
payload_parser = HttpPayloadParser(
payload, length=length,
chunked=msg.chunked, method=method,
compression=msg.compression,
code=self.code, readall=True,
response_with_body=self.response_with_body,
auto_decompress=self._auto_decompress)
if not payload_parser.done:
self._payload_parser = payload_parser
else:
payload = EMPTY_PAYLOAD # type: ignore
messages.append((msg, payload))
else:
self._tail = data[start_pos:]
data = EMPTY
break
# no parser, just store
elif self._payload_parser is None and self._upgraded:
assert not self._lines
break
# feed payload
elif data and start_pos < data_len:
assert not self._lines
assert self._payload_parser is not None
try:
eof, data = self._payload_parser.feed_data(
data[start_pos:])
except BaseException as exc:
if self.payload_exception is not None:
self._payload_parser.payload.set_exception(
self.payload_exception(str(exc)))
else:
self._payload_parser.payload.set_exception(exc)
eof = True
data = b''
if eof:
start_pos = 0
data_len = len(data)
self._payload_parser = None
continue
else:
break
if data and start_pos < data_len:
data = data[start_pos:]
else:
data = EMPTY
return messages, self._upgraded, data
def parse_headers(
self,
lines: List[bytes]
) -> Tuple['CIMultiDictProxy[str]',
RawHeaders,
Optional[bool],
Optional[str],
bool,
bool]:
"""Parses RFC 5322 headers from a stream.
Line continuations are supported. Returns the parsed headers
plus the connection-close, encoding, upgrade and chunked flags.
"""
headers, raw_headers = self._headers_parser.parse_headers(lines)
close_conn = None
encoding = None
upgrade = False
chunked = False
# keep-alive
conn = headers.get(hdrs.CONNECTION)
if conn:
v = conn.lower()
if v == 'close':
close_conn = True
elif v == 'keep-alive':
close_conn = False
elif v == 'upgrade':
upgrade = True
# encoding
enc = headers.get(hdrs.CONTENT_ENCODING)
if enc:
enc = enc.lower()
if enc in ('gzip', 'deflate', 'br'):
encoding = enc
# chunking
te = headers.get(hdrs.TRANSFER_ENCODING)
if te and 'chunked' in te.lower():
chunked = True
return (headers, raw_headers, close_conn, encoding, upgrade, chunked)
class HttpRequestParser(HttpParser):
"""Read request status line. Exception .http_exceptions.BadStatusLine
could be raised in case of any errors in status line.
Returns RawRequestMessage.
"""
def parse_message(self, lines: List[bytes]) -> Any:
# request line
line = lines[0].decode('utf-8', 'surrogateescape')
try:
method, path, version = line.split(None, 2)
except ValueError:
raise BadStatusLine(line) from None
if len(path) > self.max_line_size:
raise LineTooLong(
'Status line is too long',
str(self.max_line_size),
str(len(path)))
# method
if not METHRE.match(method):
raise BadStatusLine(method)
# version
try:
if version.startswith('HTTP/'):
n1, n2 = version[5:].split('.', 1)
version_o = HttpVersion(int(n1), int(n2))
else:
raise BadStatusLine(version)
except Exception:
raise BadStatusLine(version)
# read headers
(headers, raw_headers,
close, compression, upgrade, chunked) = self.parse_headers(lines)
if close is None: # then the headers weren't set in the request
if version_o <= HttpVersion10: # HTTP 1.0 closes by default
close = True
else: # HTTP 1.1 keeps the connection alive by default
close = False
return RawRequestMessage(
method, path, version_o, headers, raw_headers,
close, compression, upgrade, chunked, URL(path))
class HttpResponseParser(HttpParser):
"""Read response status line and headers.
BadStatusLine is raised on any error in the status line.
Returns RawResponseMessage"""
def parse_message(self, lines: List[bytes]) -> Any:
line = lines[0].decode('utf-8', 'surrogateescape')
try:
version, status = line.split(None, 1)
except ValueError:
raise BadStatusLine(line) from None
try:
status, reason = status.split(None, 1)
except ValueError:
reason = ''
if len(reason) > self.max_line_size:
raise LineTooLong(
'Status line is too long',
str(self.max_line_size),
str(len(reason)))
# version
match = VERSRE.match(version)
if match is None:
raise BadStatusLine(line)
version_o = HttpVersion(int(match.group(1)), int(match.group(2)))
# The status code is a three-digit number
try:
status_i = int(status)
except ValueError:
raise BadStatusLine(line) from None
if status_i > 999:
raise BadStatusLine(line)
# read headers
(headers, raw_headers,
close, compression, upgrade, chunked) = self.parse_headers(lines)
if close is None:
close = version_o <= HttpVersion10
return RawResponseMessage(
version_o, status_i, reason.strip(),
headers, raw_headers, close, compression, upgrade, chunked)
class HttpPayloadParser:
def __init__(self, payload: StreamReader,
length: Optional[int]=None,
chunked: bool=False,
compression: Optional[str]=None,
code: Optional[int]=None,
method: Optional[str]=None,
readall: bool=False,
response_with_body: bool=True,
auto_decompress: bool=True) -> None:
self._length = 0
self._type = ParseState.PARSE_NONE
self._chunk = ChunkState.PARSE_CHUNKED_SIZE
self._chunk_size = 0
self._chunk_tail = b''
self._auto_decompress = auto_decompress
self.done = False
# payload decompression wrapper
if response_with_body and compression and self._auto_decompress:
real_payload = DeflateBuffer(payload, compression) # type: Union[StreamReader, DeflateBuffer] # noqa
else:
real_payload = payload
# payload parser
if not response_with_body:
# don't parse payload if it's not expected to be received
self._type = ParseState.PARSE_NONE
real_payload.feed_eof()
self.done = True
elif chunked:
self._type = ParseState.PARSE_CHUNKED
elif length is not None:
self._type = ParseState.PARSE_LENGTH
self._length = length
if self._length == 0:
real_payload.feed_eof()
self.done = True
else:
if readall and code != 204:
self._type = ParseState.PARSE_UNTIL_EOF
elif method in ('PUT', 'POST'):
internal_logger.warning( # pragma: no cover
'Content-Length or Transfer-Encoding header is required')
self._type = ParseState.PARSE_NONE
real_payload.feed_eof()
self.done = True
self.payload = real_payload
def feed_eof(self) -> None:
if self._type == ParseState.PARSE_UNTIL_EOF:
self.payload.feed_eof()
elif self._type == ParseState.PARSE_LENGTH:
raise ContentLengthError(
"Not enough data for satisfy content length header.")
elif self._type == ParseState.PARSE_CHUNKED:
raise TransferEncodingError(
"Not enough data for satisfy transfer length header.")
def feed_data(self,
chunk: bytes,
SEP: bytes=b'\r\n',
CHUNK_EXT: bytes=b';') -> Tuple[bool, bytes]:
# Read specified amount of bytes
if self._type == ParseState.PARSE_LENGTH:
required = self._length
chunk_len = len(chunk)
if required >= chunk_len:
self._length = required - chunk_len
self.payload.feed_data(chunk, chunk_len)
if self._length == 0:
self.payload.feed_eof()
return True, b''
else:
self._length = 0
self.payload.feed_data(chunk[:required], required)
self.payload.feed_eof()
return True, chunk[required:]
# Chunked transfer encoding parser
elif self._type == ParseState.PARSE_CHUNKED:
if self._chunk_tail:
chunk = self._chunk_tail + chunk
self._chunk_tail = b''
while chunk:
# read next chunk size
if self._chunk == ChunkState.PARSE_CHUNKED_SIZE:
pos = chunk.find(SEP)
if pos >= 0:
i = chunk.find(CHUNK_EXT, 0, pos)
if i >= 0:
size_b = chunk[:i] # strip chunk-extensions
else:
size_b = chunk[:pos]
try:
size = int(bytes(size_b), 16)
except ValueError:
exc = TransferEncodingError(
chunk[:pos].decode('ascii', 'surrogateescape'))
self.payload.set_exception(exc)
raise exc from None
chunk = chunk[pos+2:]
if size == 0: # eof marker
self._chunk = ChunkState.PARSE_MAYBE_TRAILERS
else:
self._chunk = ChunkState.PARSE_CHUNKED_CHUNK
self._chunk_size = size
self.payload.begin_http_chunk_receiving()
else:
self._chunk_tail = chunk
return False, b''
# read chunk and feed buffer
if self._chunk == ChunkState.PARSE_CHUNKED_CHUNK:
required = self._chunk_size
chunk_len = len(chunk)
if required > chunk_len:
self._chunk_size = required - chunk_len
self.payload.feed_data(chunk, chunk_len)
return False, b''
else:
self._chunk_size = 0
self.payload.feed_data(chunk[:required], required)
chunk = chunk[required:]
self._chunk = ChunkState.PARSE_CHUNKED_CHUNK_EOF
self.payload.end_http_chunk_receiving()
# toss the CRLF at the end of the chunk
if self._chunk == ChunkState.PARSE_CHUNKED_CHUNK_EOF:
if chunk[:2] == SEP:
chunk = chunk[2:]
self._chunk = ChunkState.PARSE_CHUNKED_SIZE
else:
self._chunk_tail = chunk
return False, b''
# if the stream does not contain a trailer, after 0\r\n
# we should get another \r\n; otherwise
# trailers need to be skipped until \r\n\r\n
if self._chunk == ChunkState.PARSE_MAYBE_TRAILERS:
if chunk[:2] == SEP:
# end of stream
self.payload.feed_eof()
return True, chunk[2:]
else:
self._chunk = ChunkState.PARSE_TRAILERS
# read and discard trailer up to the CRLF terminator
if self._chunk == ChunkState.PARSE_TRAILERS:
pos = chunk.find(SEP)
if pos >= 0:
chunk = chunk[pos+2:]
self._chunk = ChunkState.PARSE_MAYBE_TRAILERS
else:
self._chunk_tail = chunk
return False, b''
# Read all bytes until eof
elif self._type == ParseState.PARSE_UNTIL_EOF:
self.payload.feed_data(chunk, len(chunk))
return False, b''
class DeflateBuffer:
"""DeflateStream decompress stream and feed data into specified stream."""
def __init__(self, out: StreamReader, encoding: Optional[str]) -> None:
self.out = out
self.size = 0
self.encoding = encoding
self._started_decoding = False
if encoding == 'br':
if not HAS_BROTLI: # pragma: no cover
raise ContentEncodingError(
'Can not decode content-encoding: brotli (br). '
'Please install `brotlipy`')
self.decompressor = brotli.Decompressor()
else:
zlib_mode = (16 + zlib.MAX_WBITS
if encoding == 'gzip' else -zlib.MAX_WBITS)
self.decompressor = zlib.decompressobj(wbits=zlib_mode)
def set_exception(self, exc: BaseException) -> None:
self.out.set_exception(exc)
def feed_data(self, chunk: bytes, size: int) -> None:
self.size += size
try:
chunk = self.decompressor.decompress(chunk)
except Exception:
if not self._started_decoding and self.encoding == 'deflate':
self.decompressor = zlib.decompressobj()
try:
chunk = self.decompressor.decompress(chunk)
except Exception:
raise ContentEncodingError(
'Can not decode content-encoding: %s' % self.encoding)
else:
raise ContentEncodingError(
'Can not decode content-encoding: %s' % self.encoding)
if chunk:
self._started_decoding = True
self.out.feed_data(chunk, len(chunk))
def feed_eof(self) -> None:
chunk = self.decompressor.flush()
if chunk or self.size > 0:
self.out.feed_data(chunk, len(chunk))
if self.encoding == 'deflate' and not self.decompressor.eof:
raise ContentEncodingError('deflate')
self.out.feed_eof()
def begin_http_chunk_receiving(self) -> None:
self.out.begin_http_chunk_receiving()
def end_http_chunk_receiving(self) -> None:
self.out.end_http_chunk_receiving()
HttpRequestParserPy = HttpRequestParser
HttpResponseParserPy = HttpResponseParser
RawRequestMessagePy = RawRequestMessage
RawResponseMessagePy = RawResponseMessage
try:
if not NO_EXTENSIONS:
from ._http_parser import (HttpRequestParser, # type: ignore # noqa
HttpResponseParser,
RawRequestMessage,
RawResponseMessage)
HttpRequestParserC = HttpRequestParser
HttpResponseParserC = HttpResponseParser
RawRequestMessageC = RawRequestMessage
RawResponseMessageC = RawResponseMessage
except ImportError: # pragma: no cover
pass
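Of the classes above, HeadersParser is the easiest to exercise standalone because it needs no protocol or event loop. A sketch (import path as in the newer tree above; lines[0] is skipped as the start line, and the list must end with an empty line):

from aiohttp.http_parser import HeadersParser

lines = [
    b'GET / HTTP/1.1',         # start line, skipped by parse_headers()
    b'Host: example.com',
    b'Content-Type: text/html;',
    b' charset=utf-8',         # continuation line (leading space or tab)
    b'',                       # blank line terminates the header block
]
headers, raw_headers = HeadersParser().parse_headers(lines)
print(headers['Content-Type'])   # text/html; charset=utf-8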

@ -1,659 +0,0 @@
"""WebSocket protocol versions 13 and 8."""
import asyncio
import collections
import json
import random
import re
import sys
import zlib
from enum import IntEnum
from struct import Struct
from typing import Any, Callable, List, Optional, Tuple, Union
from .base_protocol import BaseProtocol
from .helpers import NO_EXTENSIONS
from .log import ws_logger
from .streams import DataQueue
__all__ = ('WS_CLOSED_MESSAGE', 'WS_CLOSING_MESSAGE', 'WS_KEY',
'WebSocketReader', 'WebSocketWriter', 'WSMessage',
'WebSocketError', 'WSMsgType', 'WSCloseCode')
class WSCloseCode(IntEnum):
OK = 1000
GOING_AWAY = 1001
PROTOCOL_ERROR = 1002
UNSUPPORTED_DATA = 1003
INVALID_TEXT = 1007
POLICY_VIOLATION = 1008
MESSAGE_TOO_BIG = 1009
MANDATORY_EXTENSION = 1010
INTERNAL_ERROR = 1011
SERVICE_RESTART = 1012
TRY_AGAIN_LATER = 1013
ALLOWED_CLOSE_CODES = {int(i) for i in WSCloseCode}
class WSMsgType(IntEnum):
# websocket spec types
CONTINUATION = 0x0
TEXT = 0x1
BINARY = 0x2
PING = 0x9
PONG = 0xa
CLOSE = 0x8
# aiohttp specific types
CLOSING = 0x100
CLOSED = 0x101
ERROR = 0x102
text = TEXT
binary = BINARY
ping = PING
pong = PONG
close = CLOSE
closing = CLOSING
closed = CLOSED
error = ERROR
WS_KEY = b'258EAFA5-E914-47DA-95CA-C5AB0DC85B11'
UNPACK_LEN2 = Struct('!H').unpack_from
UNPACK_LEN3 = Struct('!Q').unpack_from
UNPACK_CLOSE_CODE = Struct('!H').unpack
PACK_LEN1 = Struct('!BB').pack
PACK_LEN2 = Struct('!BBH').pack
PACK_LEN3 = Struct('!BBQ').pack
PACK_CLOSE_CODE = Struct('!H').pack
MSG_SIZE = 2 ** 14
DEFAULT_LIMIT = 2 ** 16
_WSMessageBase = collections.namedtuple('_WSMessageBase',
['type', 'data', 'extra'])
class WSMessage(_WSMessageBase):
def json(self, *,
loads: Callable[[Any], Any]=json.loads) -> Any:
"""Return parsed JSON data.
.. versionadded:: 0.22
"""
return loads(self.data)
WS_CLOSED_MESSAGE = WSMessage(WSMsgType.CLOSED, None, None)
WS_CLOSING_MESSAGE = WSMessage(WSMsgType.CLOSING, None, None)
class WebSocketError(Exception):
"""WebSocket protocol parser error."""
def __init__(self, code: int, message: str) -> None:
self.code = code
super().__init__(code, message)
def __str__(self) -> str:
return self.args[1]
class WSHandshakeError(Exception):
"""WebSocket protocol handshake error."""
native_byteorder = sys.byteorder
# Used by _websocket_mask_python
_XOR_TABLE = [bytes(a ^ b for a in range(256)) for b in range(256)]
def _websocket_mask_python(mask: bytes, data: bytearray) -> None:
"""Websocket masking function.
`mask` is a `bytes` object of length 4; `data` is a `bytearray`
object of any length. The contents of `data` are masked with `mask`,
as specified in section 5.3 of RFC 6455.
Note that this function mutates the `data` argument.
This pure-python implementation may be replaced by an optimized
version when available.
"""
assert isinstance(data, bytearray), data
assert len(mask) == 4, mask
if data:
a, b, c, d = (_XOR_TABLE[n] for n in mask)
data[::4] = data[::4].translate(a)
data[1::4] = data[1::4].translate(b)
data[2::4] = data[2::4].translate(c)
data[3::4] = data[3::4].translate(d)
if NO_EXTENSIONS: # pragma: no cover
_websocket_mask = _websocket_mask_python
else:
try:
from ._websocket import _websocket_mask_cython # type: ignore
_websocket_mask = _websocket_mask_cython
except ImportError: # pragma: no cover
_websocket_mask = _websocket_mask_python
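# Because masking is a plain per-byte XOR (RFC 6455 section 5.3), applying
# the same 4-byte mask twice restores the original payload; a sketch of the
# pure-Python path above:
#
#   data = bytearray(b'hello websocket')
#   _websocket_mask_python(b'\x01\x02\x03\x04', data)   # mask in place
#   _websocket_mask_python(b'\x01\x02\x03\x04', data)   # XOR again unmasks
#   assert data == bytearray(b'hello websocket')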
_WS_DEFLATE_TRAILING = bytes([0x00, 0x00, 0xff, 0xff])
_WS_EXT_RE = re.compile(r'^(?:;\s*(?:'
r'(server_no_context_takeover)|'
r'(client_no_context_takeover)|'
r'(server_max_window_bits(?:=(\d+))?)|'
r'(client_max_window_bits(?:=(\d+))?)))*$')
_WS_EXT_RE_SPLIT = re.compile(r'permessage-deflate([^,]+)?')
def ws_ext_parse(extstr: str, isserver: bool=False) -> Tuple[int, bool]:
if not extstr:
return 0, False
compress = 0
notakeover = False
for ext in _WS_EXT_RE_SPLIT.finditer(extstr):
defext = ext.group(1)
# Return compress = 15 on a bare `permessage-deflate`
if not defext:
compress = 15
break
match = _WS_EXT_RE.match(defext)
if match:
compress = 15
if isserver:
# The server never fails to detect the compress handshake.
# The server does not need to send max window bits to the client
if match.group(4):
compress = int(match.group(4))
# Group 3 must match if group 4 matches
# zlib does not support a compress window of 8 bits
# If the compression level is unsupported,
# CONTINUE to the next extension
if compress > 15 or compress < 9:
compress = 0
continue
if match.group(1):
notakeover = True
# Ignore regex group 5 & 6 for client_max_window_bits
break
else:
if match.group(6):
compress = int(match.group(6))
# Group 5 must match if group 6 matches
# zlib does not support a compress window of 8 bits
# If the compression level is unsupported,
# FAIL the parse
if compress > 15 or compress < 9:
raise WSHandshakeError('Invalid window size')
if match.group(2):
notakeover = True
# Ignore regex group 5 & 6 for client_max_window_bits
break
# Fail on the client side if the extension does not match
elif not isserver:
raise WSHandshakeError('Extension for deflate not supported: ' +
ext.group(1))
return compress, notakeover
def ws_ext_gen(compress: int=15, isserver: bool=False,
server_notakeover: bool=False) -> str:
# client_notakeover=False is not used for the server
# zlib does not support a compress window of 8 bits
if compress < 9 or compress > 15:
raise ValueError('Compress wbits must be between 9 and 15, '
'zlib does not support wbits=8')
enabledext = ['permessage-deflate']
if not isserver:
enabledext.append('client_max_window_bits')
if compress < 15:
enabledext.append('server_max_window_bits=' + str(compress))
if server_notakeover:
enabledext.append('server_no_context_takeover')
# if client_notakeover:
# enabledext.append('client_no_context_takeover')
return '; '.join(enabledext)
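# Example round-trip of the Sec-WebSocket-Extensions header: the client
# builds an offer with ws_ext_gen() and the server parses it back with
# ws_ext_parse(isserver=True):
#
#   offer = ws_ext_gen(compress=12, isserver=False)
#   # 'permessage-deflate; client_max_window_bits; server_max_window_bits=12'
#   compress, notakeover = ws_ext_parse(offer, isserver=True)
#   # compress == 12, notakeover == False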
class WSParserState(IntEnum):
READ_HEADER = 1
READ_PAYLOAD_LENGTH = 2
READ_PAYLOAD_MASK = 3
READ_PAYLOAD = 4
class WebSocketReader:
def __init__(self, queue: DataQueue[WSMessage],
max_msg_size: int, compress: bool=True) -> None:
self.queue = queue
self._max_msg_size = max_msg_size
self._exc = None # type: Optional[BaseException]
self._partial = bytearray()
self._state = WSParserState.READ_HEADER
self._opcode = None # type: Optional[int]
self._frame_fin = False
self._frame_opcode = None # type: Optional[int]
self._frame_payload = bytearray()
self._tail = b''
self._has_mask = False
self._frame_mask = None # type: Optional[bytes]
self._payload_length = 0
self._payload_length_flag = 0
self._compressed = None # type: Optional[bool]
self._decompressobj = None # type: Any # zlib.decompressobj actually
self._compress = compress
def feed_eof(self) -> None:
self.queue.feed_eof()
def feed_data(self, data: bytes) -> Tuple[bool, bytes]:
if self._exc:
return True, data
try:
return self._feed_data(data)
except Exception as exc:
self._exc = exc
self.queue.set_exception(exc)
return True, b''
def _feed_data(self, data: bytes) -> Tuple[bool, bytes]:
for fin, opcode, payload, compressed in self.parse_frame(data):
if compressed and not self._decompressobj:
self._decompressobj = zlib.decompressobj(wbits=-zlib.MAX_WBITS)
if opcode == WSMsgType.CLOSE:
if len(payload) >= 2:
close_code = UNPACK_CLOSE_CODE(payload[:2])[0]
if (close_code < 3000 and
close_code not in ALLOWED_CLOSE_CODES):
raise WebSocketError(
WSCloseCode.PROTOCOL_ERROR,
'Invalid close code: {}'.format(close_code))
try:
close_message = payload[2:].decode('utf-8')
except UnicodeDecodeError as exc:
raise WebSocketError(
WSCloseCode.INVALID_TEXT,
'Invalid UTF-8 text message') from exc
msg = WSMessage(WSMsgType.CLOSE, close_code, close_message)
elif payload:
raise WebSocketError(
WSCloseCode.PROTOCOL_ERROR,
'Invalid close frame: {} {} {!r}'.format(
fin, opcode, payload))
else:
msg = WSMessage(WSMsgType.CLOSE, 0, '')
self.queue.feed_data(msg, 0)
elif opcode == WSMsgType.PING:
self.queue.feed_data(
WSMessage(WSMsgType.PING, payload, ''), len(payload))
elif opcode == WSMsgType.PONG:
self.queue.feed_data(
WSMessage(WSMsgType.PONG, payload, ''), len(payload))
elif opcode not in (
WSMsgType.TEXT, WSMsgType.BINARY) and self._opcode is None:
raise WebSocketError(
WSCloseCode.PROTOCOL_ERROR,
"Unexpected opcode={!r}".format(opcode))
else:
# load text/binary
if not fin:
# got partial frame payload
if opcode != WSMsgType.CONTINUATION:
self._opcode = opcode
self._partial.extend(payload)
if (self._max_msg_size and
len(self._partial) >= self._max_msg_size):
raise WebSocketError(
WSCloseCode.MESSAGE_TOO_BIG,
"Message size {} exceeds limit {}".format(
len(self._partial), self._max_msg_size))
else:
# the previous frame was not finished;
# we should get a continuation opcode
if self._partial:
if opcode != WSMsgType.CONTINUATION:
raise WebSocketError(
WSCloseCode.PROTOCOL_ERROR,
'The opcode in a non-fin frame is expected '
'to be zero, got {!r}'.format(opcode))
if opcode == WSMsgType.CONTINUATION:
assert self._opcode is not None
opcode = self._opcode
self._opcode = None
self._partial.extend(payload)
if (self._max_msg_size and
len(self._partial) >= self._max_msg_size):
raise WebSocketError(
WSCloseCode.MESSAGE_TOO_BIG,
"Message size {} exceeds limit {}".format(
len(self._partial), self._max_msg_size))
# Decompression must be done after all frames of the
# message have been received.
if compressed:
self._partial.extend(_WS_DEFLATE_TRAILING)
payload_merged = self._decompressobj.decompress(
self._partial, self._max_msg_size)
if self._decompressobj.unconsumed_tail:
left = len(self._decompressobj.unconsumed_tail)
raise WebSocketError(
WSCloseCode.MESSAGE_TOO_BIG,
"Decompressed message size {} exceeds limit {}"
.format(
self._max_msg_size + left,
self._max_msg_size
)
)
else:
payload_merged = bytes(self._partial)
self._partial.clear()
if opcode == WSMsgType.TEXT:
try:
text = payload_merged.decode('utf-8')
self.queue.feed_data(
WSMessage(WSMsgType.TEXT, text, ''), len(text))
except UnicodeDecodeError as exc:
raise WebSocketError(
WSCloseCode.INVALID_TEXT,
'Invalid UTF-8 text message') from exc
else:
self.queue.feed_data(
WSMessage(WSMsgType.BINARY, payload_merged, ''),
len(payload_merged))
return False, b''
def parse_frame(self, buf: bytes) -> List[Tuple[bool, Optional[int],
bytearray,
Optional[bool]]]:
"""Return the next frame from the socket."""
frames = []
if self._tail:
buf, self._tail = self._tail + buf, b''
start_pos = 0
buf_length = len(buf)
while True:
# read header
if self._state == WSParserState.READ_HEADER:
if buf_length - start_pos >= 2:
data = buf[start_pos:start_pos+2]
start_pos += 2
first_byte, second_byte = data
fin = (first_byte >> 7) & 1
rsv1 = (first_byte >> 6) & 1
rsv2 = (first_byte >> 5) & 1
rsv3 = (first_byte >> 4) & 1
opcode = first_byte & 0xf
# frame-fin = %x0 ; more frames of this message follow
# / %x1 ; final frame of this message
# frame-rsv1 = %x0 ;
# 1 bit, MUST be 0 unless negotiated otherwise
# frame-rsv2 = %x0 ;
# 1 bit, MUST be 0 unless negotiated otherwise
# frame-rsv3 = %x0 ;
# 1 bit, MUST be 0 unless negotiated otherwise
#
# Remove rsv1 from this test for deflate development
if rsv2 or rsv3 or (rsv1 and not self._compress):
raise WebSocketError(
WSCloseCode.PROTOCOL_ERROR,
'Received frame with non-zero reserved bits')
if opcode > 0x7 and fin == 0:
raise WebSocketError(
WSCloseCode.PROTOCOL_ERROR,
'Received fragmented control frame')
has_mask = (second_byte >> 7) & 1
length = second_byte & 0x7f
# Control frames MUST have a payload
# length of 125 bytes or less
if opcode > 0x7 and length > 125:
raise WebSocketError(
WSCloseCode.PROTOCOL_ERROR,
'Control frame payload cannot be '
'larger than 125 bytes')
# Set the compress status if the previous frame was final (FIN)
# or if this is the first fragment of a message.
# Raise an error on a non-first fragment with rsv1 = 0x1.
if self._frame_fin or self._compressed is None:
self._compressed = bool(rsv1)
elif rsv1:
raise WebSocketError(
WSCloseCode.PROTOCOL_ERROR,
'Received frame with non-zero reserved bits')
self._frame_fin = bool(fin)
self._frame_opcode = opcode
self._has_mask = bool(has_mask)
self._payload_length_flag = length
self._state = WSParserState.READ_PAYLOAD_LENGTH
else:
break
# read payload length
if self._state == WSParserState.READ_PAYLOAD_LENGTH:
length = self._payload_length_flag
if length == 126:
if buf_length - start_pos >= 2:
data = buf[start_pos:start_pos+2]
start_pos += 2
length = UNPACK_LEN2(data)[0]
self._payload_length = length
self._state = (
WSParserState.READ_PAYLOAD_MASK
if self._has_mask
else WSParserState.READ_PAYLOAD)
else:
break
elif length > 126:
if buf_length - start_pos >= 8:
data = buf[start_pos:start_pos+8]
start_pos += 8
length = UNPACK_LEN3(data)[0]
self._payload_length = length
self._state = (
WSParserState.READ_PAYLOAD_MASK
if self._has_mask
else WSParserState.READ_PAYLOAD)
else:
break
else:
self._payload_length = length
self._state = (
WSParserState.READ_PAYLOAD_MASK
if self._has_mask
else WSParserState.READ_PAYLOAD)
# read payload mask
if self._state == WSParserState.READ_PAYLOAD_MASK:
if buf_length - start_pos >= 4:
self._frame_mask = buf[start_pos:start_pos+4]
start_pos += 4
self._state = WSParserState.READ_PAYLOAD
else:
break
if self._state == WSParserState.READ_PAYLOAD:
length = self._payload_length
payload = self._frame_payload
chunk_len = buf_length - start_pos
if length >= chunk_len:
self._payload_length = length - chunk_len
payload.extend(buf[start_pos:])
start_pos = buf_length
else:
self._payload_length = 0
payload.extend(buf[start_pos:start_pos+length])
start_pos = start_pos + length
if self._payload_length == 0:
if self._has_mask:
assert self._frame_mask is not None
_websocket_mask(self._frame_mask, payload)
frames.append((
self._frame_fin,
self._frame_opcode,
payload,
self._compressed))
self._frame_payload = bytearray()
self._state = WSParserState.READ_HEADER
else:
break
self._tail = buf[start_pos:]
return frames
class WebSocketWriter:
def __init__(self, protocol: BaseProtocol, transport: asyncio.Transport, *,
use_mask: bool=False, limit: int=DEFAULT_LIMIT,
random: Any=random.Random(),
compress: int=0, notakeover: bool=False) -> None:
self.protocol = protocol
self.transport = transport
self.use_mask = use_mask
self.randrange = random.randrange
self.compress = compress
self.notakeover = notakeover
self._closing = False
self._limit = limit
self._output_size = 0
self._compressobj = None # type: Any # actually compressobj
async def _send_frame(self, message: bytes, opcode: int,
compress: Optional[int]=None) -> None:
"""Send a frame over the websocket with message as its payload."""
if self._closing:
ws_logger.warning('websocket connection is closing.')
rsv = 0
# Only compress larger packets (disabled for now)
# Is it worth compressing small packets at all?
# if self.compress and opcode < 8 and len(message) > 124:
if (compress or self.compress) and opcode < 8:
if compress:
# Do not cache the compressor when compression applies to this frame only
compressobj = zlib.compressobj(wbits=-compress)
else: # self.compress
if not self._compressobj:
self._compressobj = zlib.compressobj(wbits=-self.compress)
compressobj = self._compressobj
message = compressobj.compress(message)
message = message + compressobj.flush(
zlib.Z_FULL_FLUSH if self.notakeover else zlib.Z_SYNC_FLUSH)
if message.endswith(_WS_DEFLATE_TRAILING):
message = message[:-4]
rsv = rsv | 0x40
msg_length = len(message)
use_mask = self.use_mask
if use_mask:
mask_bit = 0x80
else:
mask_bit = 0
if msg_length < 126:
header = PACK_LEN1(0x80 | rsv | opcode, msg_length | mask_bit)
elif msg_length < (1 << 16):
header = PACK_LEN2(0x80 | rsv | opcode, 126 | mask_bit, msg_length)
else:
header = PACK_LEN3(0x80 | rsv | opcode, 127 | mask_bit, msg_length)
if use_mask:
mask = self.randrange(0, 0xffffffff)
mask = mask.to_bytes(4, 'big')
message = bytearray(message)
_websocket_mask(mask, message)
self.transport.write(header + mask + message)
self._output_size += len(header) + len(mask) + len(message)
else:
if len(message) > MSG_SIZE:
self.transport.write(header)
self.transport.write(message)
else:
self.transport.write(header + message)
self._output_size += len(header) + len(message)
if self._output_size > self._limit:
self._output_size = 0
await self.protocol._drain_helper()
async def pong(self, message: bytes=b'') -> None:
"""Send pong message."""
if isinstance(message, str):
message = message.encode('utf-8')
await self._send_frame(message, WSMsgType.PONG)
async def ping(self, message: bytes=b'') -> None:
"""Send ping message."""
if isinstance(message, str):
message = message.encode('utf-8')
await self._send_frame(message, WSMsgType.PING)
async def send(self, message: Union[str, bytes],
binary: bool=False,
compress: Optional[int]=None) -> None:
"""Send a frame over the websocket with message as its payload."""
if isinstance(message, str):
message = message.encode('utf-8')
if binary:
await self._send_frame(message, WSMsgType.BINARY, compress)
else:
await self._send_frame(message, WSMsgType.TEXT, compress)
async def close(self, code: int=1000, message: bytes=b'') -> None:
"""Close the websocket, sending the specified code and message."""
if isinstance(message, str):
message = message.encode('utf-8')
try:
await self._send_frame(
PACK_CLOSE_CODE(code) + message, opcode=WSMsgType.CLOSE)
finally:
self._closing = True
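# Illustrative sketch (an assumption, not part of the original file): the
# three RFC 6455 header layouts that the PACK_LEN1/2/3 calls above
# correspond to, assuming they wrap struct.Struct formats '!BB', '!BBH'
# and '!BBQ' respectively.
import struct

def _build_frame_header(opcode: int, payload_len: int,
                        rsv: int = 0, masked: bool = False) -> bytes:
    first = 0x80 | rsv | opcode  # FIN bit set, RSV bits, opcode
    mask_bit = 0x80 if masked else 0
    if payload_len < 126:
        return struct.pack('!BB', first, payload_len | mask_bit)
    elif payload_len < (1 << 16):
        return struct.pack('!BBH', first, 126 | mask_bit, payload_len)
    else:
        return struct.pack('!BBQ', first, 127 | mask_bit, payload_len)

# A 5-byte unmasked text frame (opcode 0x1) needs only the 2-byte header:
assert _build_frame_header(0x1, 5) == b'\x81\x05'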

@@ -1,172 +0,0 @@
"""Http related parsers and protocol."""
import asyncio
import collections
import zlib
from typing import Any, Awaitable, Callable, Optional, Union # noqa
from multidict import CIMultiDict # noqa
from .abc import AbstractStreamWriter
from .base_protocol import BaseProtocol
from .helpers import NO_EXTENSIONS
__all__ = ('StreamWriter', 'HttpVersion', 'HttpVersion10', 'HttpVersion11')
HttpVersion = collections.namedtuple('HttpVersion', ['major', 'minor'])
HttpVersion10 = HttpVersion(1, 0)
HttpVersion11 = HttpVersion(1, 1)
_T_OnChunkSent = Optional[Callable[[bytes], Awaitable[None]]]
class StreamWriter(AbstractStreamWriter):
def __init__(self,
protocol: BaseProtocol,
loop: asyncio.AbstractEventLoop,
on_chunk_sent: _T_OnChunkSent = None) -> None:
self._protocol = protocol
self._transport = protocol.transport
self.loop = loop
self.length = None
self.chunked = False
self.buffer_size = 0
self.output_size = 0
self._eof = False
self._compress = None # type: Any
self._drain_waiter = None
self._on_chunk_sent = on_chunk_sent # type: _T_OnChunkSent
@property
def transport(self) -> Optional[asyncio.Transport]:
return self._transport
@property
def protocol(self) -> BaseProtocol:
return self._protocol
def enable_chunking(self) -> None:
self.chunked = True
def enable_compression(self, encoding: str='deflate') -> None:
zlib_mode = (16 + zlib.MAX_WBITS
if encoding == 'gzip' else -zlib.MAX_WBITS)
self._compress = zlib.compressobj(wbits=zlib_mode)
def _write(self, chunk: bytes) -> None:
size = len(chunk)
self.buffer_size += size
self.output_size += size
if self._transport is None or self._transport.is_closing():
raise ConnectionResetError('Cannot write to closing transport')
self._transport.write(chunk)
async def write(self, chunk: bytes,
*, drain: bool=True, LIMIT: int=0x10000) -> None:
"""Writes chunk of data to a stream.
write_eof() indicates end of stream.
writer can't be used after write_eof() method being called.
write() return drain future.
"""
if self._on_chunk_sent is not None:
await self._on_chunk_sent(chunk)
if self._compress is not None:
chunk = self._compress.compress(chunk)
if not chunk:
return
if self.length is not None:
chunk_len = len(chunk)
if self.length >= chunk_len:
self.length = self.length - chunk_len
else:
chunk = chunk[:self.length]
self.length = 0
if not chunk:
return
if chunk:
if self.chunked:
chunk_len_pre = ('%x\r\n' % len(chunk)).encode('ascii')
chunk = chunk_len_pre + chunk + b'\r\n'
self._write(chunk)
if self.buffer_size > LIMIT and drain:
self.buffer_size = 0
await self.drain()
async def write_headers(self, status_line: str,
headers: 'CIMultiDict[str]') -> None:
"""Write request/response status and headers."""
# status + headers
buf = _serialize_headers(status_line, headers)
self._write(buf)
async def write_eof(self, chunk: bytes=b'') -> None:
if self._eof:
return
if chunk and self._on_chunk_sent is not None:
await self._on_chunk_sent(chunk)
if self._compress:
if chunk:
chunk = self._compress.compress(chunk)
chunk = chunk + self._compress.flush()
if chunk and self.chunked:
chunk_len = ('%x\r\n' % len(chunk)).encode('ascii')
chunk = chunk_len + chunk + b'\r\n0\r\n\r\n'
else:
if self.chunked:
if chunk:
chunk_len = ('%x\r\n' % len(chunk)).encode('ascii')
chunk = chunk_len + chunk + b'\r\n0\r\n\r\n'
else:
chunk = b'0\r\n\r\n'
if chunk:
self._write(chunk)
await self.drain()
self._eof = True
self._transport = None
async def drain(self) -> None:
"""Flush the write buffer.
The intended use is to write
await w.write(data)
await w.drain()
"""
if self._protocol.transport is not None:
await self._protocol._drain_helper()
def _py_serialize_headers(status_line: str,
headers: 'CIMultiDict[str]') -> bytes:
line = status_line + '\r\n' + ''.join(
[k + ': ' + v + '\r\n' for k, v in headers.items()])
return line.encode('utf-8') + b'\r\n'
_serialize_headers = _py_serialize_headers
try:
import aiohttp._http_writer as _http_writer # type: ignore
_c_serialize_headers = _http_writer._serialize_headers
if not NO_EXTENSIONS:
_serialize_headers = _c_serialize_headers
except ImportError:
pass
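# Illustrative sketch (an assumption, not part of the original file): the
# chunked framing that write() and write_eof() above produce, stated as
# plain functions over bytes.
def _frame_chunk(chunk: bytes) -> bytes:
    # '<size-in-hex>\r\n<payload>\r\n', exactly as write() frames it
    return ('%x\r\n' % len(chunk)).encode('ascii') + chunk + b'\r\n'

# write_eof() terminates the stream with a zero-length chunk:
_CHUNKED_EOF = b'0\r\n\r\n'

assert _frame_chunk(b'hello') == b'5\r\nhello\r\n'
assert _frame_chunk(b'x' * 16) == b'10\r\nxxxxxxxxxxxxxxxx\r\n'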

@@ -1,44 +0,0 @@
import asyncio
import collections
from typing import Any, Optional
try:
from typing import Deque
except ImportError:
from typing_extensions import Deque # noqa
class EventResultOrError:
"""
This class wrappers the Event asyncio lock allowing either awake the
locked Tasks without any error or raising an exception.
thanks to @vorpalsmith for the simple design.
"""
def __init__(self, loop: asyncio.AbstractEventLoop) -> None:
self._loop = loop
self._exc = None # type: Optional[BaseException]
self._event = asyncio.Event(loop=loop)
self._waiters = collections.deque() # type: Deque[asyncio.Future[Any]]
def set(self, exc: Optional[BaseException]=None) -> None:
self._exc = exc
self._event.set()
async def wait(self) -> Any:
waiter = self._loop.create_task(self._event.wait())
self._waiters.append(waiter)
try:
val = await waiter
finally:
self._waiters.remove(waiter)
if self._exc is not None:
raise self._exc
return val
def cancel(self) -> None:
""" Cancel all waiters """
for waiter in self._waiters:
waiter.cancel()
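# Illustrative usage sketch (an assumption, not part of the original file):
# one task waits on the event while another wakes it, either normally or
# with an exception carried through set().
async def _demo(loop: asyncio.AbstractEventLoop) -> None:
    ev = EventResultOrError(loop)

    async def waiter() -> str:
        await ev.wait()  # raises here if set() carried an exception
        return 'woken'

    task = loop.create_task(waiter())
    await asyncio.sleep(0)  # let the waiter subscribe
    ev.set()  # or ev.set(RuntimeError('boom')) to fail all waiters
    assert await task == 'woken'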

File diff suppressed because it is too large

@@ -0,0 +1,495 @@
"""Parser is a generator function (NOT coroutine).
Parser receives data with generator's send() method and sends data to
destination DataQueue. Parser receives ParserBuffer and DataQueue objects
as a parameters of the parser call, all subsequent send() calls should
send bytes objects. Parser sends parsed `term` to destination buffer with
DataQueue.feed_data() method. DataQueue object should implement two methods.
feed_data() - parser uses this method to send parsed protocol data.
feed_eof() - parser uses this method for indication of end of parsing stream.
To indicate end of incoming data stream EofStream exception should be sent
into parser. Parser could throw exceptions.
There are three stages:
* Setup chain:
1. Application creates a StreamParser object for storing incoming data.
2. StreamParser creates a ParserBuffer as its internal data buffer.
3. Application creates a parser and sets it into the stream buffer:
parser = HttpRequestParser()
data_queue = stream.set_parser(parser)
4. At this stage StreamParser creates a DataQueue object and passes it,
together with the internal buffer, into the parser as arguments.
def set_parser(self, parser):
output = DataQueue()
self.p = parser(output, self._input)
return output
5. Application waits for data on output.read():
while True:
msg = yield from output.read()
...
* Data flow:
1. asyncio's transport reads data from socket and sends data to protocol
with data_received() call.
2. Protocol sends data to StreamParser with feed_data() call.
3. StreamParser sends data into parser with generator's send() method.
4. Parser processes incoming data and sends parsed data
to DataQueue with feed_data().
5. Application receives parsed data from DataQueue.read().
* Eof:
1. StreamParser receives eof with feed_eof() call.
2. StreamParser throws EofStream exception into parser.
3. Then it unsets the parser.
_SocketSocketTransport ->
-> "protocol" -> StreamParser -> "parser" -> DataQueue <- "application"
"""
import asyncio
import asyncio.streams
import inspect
import socket
from . import errors
from .streams import EofStream, FlowControlDataQueue
__all__ = ('EofStream', 'StreamParser', 'StreamProtocol',
'ParserBuffer', 'StreamWriter')
DEFAULT_LIMIT = 2 ** 16
if hasattr(socket, 'TCP_CORK'): # pragma: no cover
CORK = socket.TCP_CORK
elif hasattr(socket, 'TCP_NOPUSH'): # pragma: no cover
CORK = socket.TCP_NOPUSH
else: # pragma: no cover
CORK = None
class StreamParser:
"""StreamParser manages incoming bytes stream and protocol parsers.
StreamParser uses ParserBuffer as internal buffer.
set_parser() sets current parser, it creates DataQueue object
and sends ParserBuffer and DataQueue into parser generator.
unset_parser() sends EofStream into parser and then removes it.
"""
def __init__(self, *, loop=None, buf=None,
limit=DEFAULT_LIMIT, eof_exc_class=RuntimeError, **kwargs):
self._loop = loop
self._eof = False
self._exception = None
self._parser = None
self._output = None
self._limit = limit
self._eof_exc_class = eof_exc_class
self._buffer = buf if buf is not None else ParserBuffer()
self.paused = False
self.transport = None
@property
def output(self):
return self._output
def set_transport(self, transport):
assert transport is None or self.transport is None, \
'Transport already set'
self.transport = transport
def at_eof(self):
return self._eof
def exception(self):
return self._exception
def set_exception(self, exc):
if isinstance(exc, ConnectionError):
exc, old_exc = self._eof_exc_class(), exc
exc.__cause__ = old_exc
exc.__context__ = old_exc
self._exception = exc
if self._output is not None:
self._output.set_exception(exc)
self._output = None
self._parser = None
def feed_data(self, data):
"""send data to current parser or store in buffer."""
if data is None:
return
if self._parser:
try:
self._parser.send(data)
except StopIteration:
self._output.feed_eof()
self._output = None
self._parser = None
except Exception as exc:
self._output.set_exception(exc)
self._output = None
self._parser = None
else:
self._buffer.feed_data(data)
def feed_eof(self):
"""send eof to all parsers, recursively."""
if self._parser:
try:
if self._buffer:
self._parser.send(b'')
self._parser.throw(EofStream())
except StopIteration:
self._output.feed_eof()
except EofStream:
self._output.set_exception(self._eof_exc_class())
except Exception as exc:
self._output.set_exception(exc)
self._parser = None
self._output = None
self._eof = True
def set_parser(self, parser, output=None):
"""set parser to stream. return parser's DataQueue."""
if self._parser:
self.unset_parser()
if output is None:
output = FlowControlDataQueue(
self, limit=self._limit, loop=self._loop)
if self._exception:
output.set_exception(self._exception)
return output
# init parser
p = parser(output, self._buffer)
assert inspect.isgenerator(p), 'Generator is required'
try:
# initialize parser with data and parser buffers
next(p)
except StopIteration:
pass
except Exception as exc:
output.set_exception(exc)
else:
# parser still requires more data
self._parser = p
self._output = output
if self._eof:
self.unset_parser()
return output
def unset_parser(self):
"""unset parser, send eof to the parser and then remove it."""
if self._parser is None:
return
# TODO: write test
if self._loop.is_closed():
# TODO: log something
return
try:
self._parser.throw(EofStream())
except StopIteration:
self._output.feed_eof()
except EofStream:
self._output.set_exception(self._eof_exc_class())
except Exception as exc:
self._output.set_exception(exc)
finally:
self._output = None
self._parser = None
class StreamWriter(asyncio.streams.StreamWriter):
def __init__(self, transport, protocol, reader, loop):
self._transport = transport
self._protocol = protocol
self._reader = reader
self._loop = loop
self._tcp_nodelay = False
self._tcp_cork = False
self._socket = transport.get_extra_info('socket')
@property
def tcp_nodelay(self):
return self._tcp_nodelay
def set_tcp_nodelay(self, value):
value = bool(value)
if self._tcp_nodelay == value:
return
self._tcp_nodelay = value
if self._socket is None:
return
if self._socket.family not in (socket.AF_INET, socket.AF_INET6):
return
if self._tcp_cork:
self._tcp_cork = False
if CORK is not None: # pragma: no branch
self._socket.setsockopt(socket.IPPROTO_TCP, CORK, False)
self._socket.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, value)
@property
def tcp_cork(self):
return self._tcp_cork
def set_tcp_cork(self, value):
value = bool(value)
if self._tcp_cork == value:
return
self._tcp_cork = value
if self._socket is None:
return
if self._socket.family not in (socket.AF_INET, socket.AF_INET6):
return
if self._tcp_nodelay:
self._socket.setsockopt(socket.IPPROTO_TCP,
socket.TCP_NODELAY,
False)
self._tcp_nodelay = False
if CORK is not None: # pragma: no branch
self._socket.setsockopt(socket.IPPROTO_TCP, CORK, value)
class StreamProtocol(asyncio.streams.FlowControlMixin, asyncio.Protocol):
"""Helper class to adapt between Protocol and StreamReader."""
def __init__(self, *, loop=None, disconnect_error=RuntimeError, **kwargs):
super().__init__(loop=loop)
self.transport = None
self.writer = None
self.reader = StreamParser(
loop=loop, eof_exc_class=disconnect_error, **kwargs)
def is_connected(self):
return self.transport is not None
def connection_made(self, transport):
self.transport = transport
self.reader.set_transport(transport)
self.writer = StreamWriter(transport, self, self.reader, self._loop)
def connection_lost(self, exc):
self.transport = self.writer = None
self.reader.set_transport(None)
if exc is None:
self.reader.feed_eof()
else:
self.reader.set_exception(exc)
super().connection_lost(exc)
def data_received(self, data):
self.reader.feed_data(data)
def eof_received(self):
self.reader.feed_eof()
class _ParserBufferHelper:
__slots__ = ('exception', 'data')
def __init__(self, exception, data):
self.exception = exception
self.data = data
class ParserBuffer:
"""ParserBuffer is NOT a bytearray extension anymore.
ParserBuffer provides helper methods for parsers.
"""
__slots__ = ('_helper', '_writer', '_data')
def __init__(self, *args):
self._data = bytearray(*args)
self._helper = _ParserBufferHelper(None, self._data)
self._writer = self._feed_data(self._helper)
next(self._writer)
def exception(self):
return self._helper.exception
def set_exception(self, exc):
self._helper.exception = exc
@staticmethod
def _feed_data(helper):
while True:
chunk = yield
if chunk:
helper.data.extend(chunk)
if helper.exception:
raise helper.exception
def feed_data(self, data):
if not self._helper.exception:
self._writer.send(data)
def read(self, size):
"""read() reads specified amount of bytes."""
while True:
if self._helper.exception:
raise self._helper.exception
if len(self._data) >= size:
data = self._data[:size]
del self._data[:size]
return data
self._writer.send((yield))
def readsome(self, size=None):
"""reads size of less amount of bytes."""
while True:
if self._helper.exception:
raise self._helper.exception
length = len(self._data)
if length > 0:
if size is None or length < size:
size = length
data = self._data[:size]
del self._data[:size]
return data
self._writer.send((yield))
def readuntil(self, stop, limit=None):
assert isinstance(stop, bytes) and stop, \
'bytes is required: {!r}'.format(stop)
stop_len = len(stop)
while True:
if self._helper.exception:
raise self._helper.exception
pos = self._data.find(stop)
if pos >= 0:
end = pos + stop_len
size = end
if limit is not None and size > limit:
raise errors.LineLimitExceededParserError(
'Line is too long.', limit)
data = self._data[:size]
del self._data[:size]
return data
else:
if limit is not None and len(self._data) > limit:
raise errors.LineLimitExceededParserError(
'Line is too long.', limit)
self._writer.send((yield))
def wait(self, size):
"""wait() waits for specified amount of bytes
then returns data without changing internal buffer."""
while True:
if self._helper.exception:
raise self._helper.exception
if len(self._data) >= size:
return self._data[:size]
self._writer.send((yield))
def waituntil(self, stop, limit=None):
"""waituntil() reads until `stop` bytes sequence."""
assert isinstance(stop, bytes) and stop, \
'bytes is required: {!r}'.format(stop)
stop_len = len(stop)
while True:
if self._helper.exception:
raise self._helper.exception
pos = self._data.find(stop)
if pos >= 0:
size = pos + stop_len
if limit is not None and size > limit:
raise errors.LineLimitExceededParserError(
'Line is too long. %s' % bytes(self._data), limit)
return self._data[:size]
else:
if limit is not None and len(self._data) > limit:
raise errors.LineLimitExceededParserError(
'Line is too long. %s' % bytes(self._data), limit)
self._writer.send((yield))
def skip(self, size):
"""skip() skips specified amount of bytes."""
while len(self._data) < size:
if self._helper.exception:
raise self._helper.exception
self._writer.send((yield))
del self._data[:size]
def skipuntil(self, stop):
"""skipuntil() reads until `stop` bytes sequence."""
assert isinstance(stop, bytes) and stop, \
'bytes is required: {!r}'.format(stop)
stop_len = len(stop)
while True:
if self._helper.exception:
raise self._helper.exception
stop_line = self._data.find(stop)
if stop_line >= 0:
size = stop_line + stop_len
del self._data[:size]
return
self._writer.send((yield))
def extend(self, data):
self._data.extend(data)
def __len__(self):
return len(self._data)
def __bytes__(self):
return bytes(self._data)
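# Illustrative sketch (an assumption, not part of aiohttp): a minimal line
# parser written against the generator protocol described at the top of this
# file. The queue below is a stand-in stub, not aiohttp's DataQueue.
class _StubQueue:
    def __init__(self):
        self.items = []
    def feed_data(self, data, size):
        self.items.append(data)
    def feed_eof(self):
        self.items.append(None)

def _line_parser(out, buf):
    while True:
        line = yield from buf.readuntil(b'\n')
        out.feed_data(line, len(line))

_out = _StubQueue()
_buf = ParserBuffer()
_p = _line_parser(_out, _buf)
next(_p)                  # prime the generator
_p.send(b'first\nsec')    # one full line, one partial
_p.send(b'ond\n')         # completes the second line
assert _out.items == [b'first\n', b'second\n']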

@@ -1,456 +0,0 @@
import asyncio
import enum
import io
import json
import mimetypes
import os
import warnings
from abc import ABC, abstractmethod
from itertools import chain
from typing import (
IO,
TYPE_CHECKING,
Any,
ByteString,
Dict,
Iterable,
Optional,
Text,
TextIO,
Tuple,
Type,
Union,
)
from multidict import CIMultiDict
from . import hdrs
from .abc import AbstractStreamWriter
from .helpers import (
PY_36,
content_disposition_header,
guess_filename,
parse_mimetype,
sentinel,
)
from .streams import DEFAULT_LIMIT, StreamReader
from .typedefs import JSONEncoder, _CIMultiDict
__all__ = ('PAYLOAD_REGISTRY', 'get_payload', 'payload_type', 'Payload',
'BytesPayload', 'StringPayload',
'IOBasePayload', 'BytesIOPayload', 'BufferedReaderPayload',
'TextIOPayload', 'StringIOPayload', 'JsonPayload',
'AsyncIterablePayload')
TOO_LARGE_BYTES_BODY = 2 ** 20 # 1 MB
if TYPE_CHECKING: # pragma: no cover
from typing import List # noqa
class LookupError(Exception):
pass
class Order(str, enum.Enum):
normal = 'normal'
try_first = 'try_first'
try_last = 'try_last'
def get_payload(data: Any, *args: Any, **kwargs: Any) -> 'Payload':
return PAYLOAD_REGISTRY.get(data, *args, **kwargs)
def register_payload(factory: Type['Payload'],
type: Any,
*,
order: Order=Order.normal) -> None:
PAYLOAD_REGISTRY.register(factory, type, order=order)
class payload_type:
def __init__(self, type: Any, *, order: Order=Order.normal) -> None:
self.type = type
self.order = order
def __call__(self, factory: Type['Payload']) -> Type['Payload']:
register_payload(factory, self.type, order=self.order)
return factory
class PayloadRegistry:
"""Payload registry.
note: we need zope.interface for more efficient adapter search
"""
def __init__(self) -> None:
self._first = [] # type: List[Tuple[Type[Payload], Any]]
self._normal = [] # type: List[Tuple[Type[Payload], Any]]
self._last = [] # type: List[Tuple[Type[Payload], Any]]
def get(self,
data: Any,
*args: Any,
_CHAIN: Any=chain,
**kwargs: Any) -> 'Payload':
if isinstance(data, Payload):
return data
for factory, type in _CHAIN(self._first, self._normal, self._last):
if isinstance(data, type):
return factory(data, *args, **kwargs)
raise LookupError()
def register(self,
factory: Type['Payload'],
type: Any,
*,
order: Order=Order.normal) -> None:
if order is Order.try_first:
self._first.append((factory, type))
elif order is Order.normal:
self._normal.append((factory, type))
elif order is Order.try_last:
self._last.append((factory, type))
else:
raise ValueError("Unsupported order {!r}".format(order))
class Payload(ABC):
_default_content_type = 'application/octet-stream' # type: str
_size = None # type: Optional[int]
def __init__(self,
value: Any,
headers: Optional[
Union[
_CIMultiDict,
Dict[str, str],
Iterable[Tuple[str, str]]
]
] = None,
content_type: Optional[str]=sentinel,
filename: Optional[str]=None,
encoding: Optional[str]=None,
**kwargs: Any) -> None:
self._encoding = encoding
self._filename = filename
self._headers = CIMultiDict() # type: _CIMultiDict
self._value = value
if content_type is not sentinel and content_type is not None:
self._headers[hdrs.CONTENT_TYPE] = content_type
elif self._filename is not None:
content_type = mimetypes.guess_type(self._filename)[0]
if content_type is None:
content_type = self._default_content_type
self._headers[hdrs.CONTENT_TYPE] = content_type
else:
self._headers[hdrs.CONTENT_TYPE] = self._default_content_type
self._headers.update(headers or {})
@property
def size(self) -> Optional[int]:
"""Size of the payload."""
return self._size
@property
def filename(self) -> Optional[str]:
"""Filename of the payload."""
return self._filename
@property
def headers(self) -> _CIMultiDict:
"""Custom item headers"""
return self._headers
@property
def _binary_headers(self) -> bytes:
return ''.join(
[k + ': ' + v + '\r\n' for k, v in self.headers.items()]
).encode('utf-8') + b'\r\n'
@property
def encoding(self) -> Optional[str]:
"""Payload encoding"""
return self._encoding
@property
def content_type(self) -> str:
"""Content type"""
return self._headers[hdrs.CONTENT_TYPE]
def set_content_disposition(self,
disptype: str,
quote_fields: bool=True,
**params: Any) -> None:
"""Sets ``Content-Disposition`` header."""
self._headers[hdrs.CONTENT_DISPOSITION] = content_disposition_header(
disptype, quote_fields=quote_fields, **params)
@abstractmethod
async def write(self, writer: AbstractStreamWriter) -> None:
"""Write payload.
writer is an AbstractStreamWriter instance:
"""
class BytesPayload(Payload):
def __init__(self,
value: ByteString,
*args: Any,
**kwargs: Any) -> None:
if not isinstance(value, (bytes, bytearray, memoryview)):
raise TypeError("value argument must be byte-ish, not {!r}"
.format(type(value)))
if 'content_type' not in kwargs:
kwargs['content_type'] = 'application/octet-stream'
super().__init__(value, *args, **kwargs)
self._size = len(value)
if self._size > TOO_LARGE_BYTES_BODY:
if PY_36:
kwargs = {'source': self}
else:
kwargs = {}
warnings.warn("Sending a large body directly with raw bytes might"
" lock the event loop. You should probably pass an "
"io.BytesIO object instead", ResourceWarning,
**kwargs)
async def write(self, writer: AbstractStreamWriter) -> None:
await writer.write(self._value)
class StringPayload(BytesPayload):
def __init__(self,
value: Text,
*args: Any,
encoding: Optional[str]=None,
content_type: Optional[str]=None,
**kwargs: Any) -> None:
if encoding is None:
if content_type is None:
real_encoding = 'utf-8'
content_type = 'text/plain; charset=utf-8'
else:
mimetype = parse_mimetype(content_type)
real_encoding = mimetype.parameters.get('charset', 'utf-8')
else:
if content_type is None:
content_type = 'text/plain; charset=%s' % encoding
real_encoding = encoding
super().__init__(
value.encode(real_encoding),
encoding=real_encoding,
content_type=content_type,
*args,
**kwargs,
)
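# Illustrative examples (assumed) of the charset resolution above:
#   StringPayload('hi')
#       -> utf-8, content type 'text/plain; charset=utf-8'
#   StringPayload('hi', content_type='text/plain; charset=latin-1')
#       -> encoded with latin-1, taken from the content type
#   StringPayload('hi', encoding='ascii')
#       -> ascii, content type 'text/plain; charset=ascii'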
class StringIOPayload(StringPayload):
def __init__(self,
value: IO[str],
*args: Any,
**kwargs: Any) -> None:
super().__init__(value.read(), *args, **kwargs)
class IOBasePayload(Payload):
def __init__(self,
value: IO[Any],
disposition: str='attachment',
*args: Any,
**kwargs: Any) -> None:
if 'filename' not in kwargs:
kwargs['filename'] = guess_filename(value)
super().__init__(value, *args, **kwargs)
if self._filename is not None and disposition is not None:
if hdrs.CONTENT_DISPOSITION not in self.headers:
self.set_content_disposition(
disposition, filename=self._filename
)
async def write(self, writer: AbstractStreamWriter) -> None:
loop = asyncio.get_event_loop()
try:
chunk = await loop.run_in_executor(
None, self._value.read, DEFAULT_LIMIT
)
while chunk:
await writer.write(chunk)
chunk = await loop.run_in_executor(
None, self._value.read, DEFAULT_LIMIT
)
finally:
await loop.run_in_executor(None, self._value.close)
class TextIOPayload(IOBasePayload):
def __init__(self,
value: TextIO,
*args: Any,
encoding: Optional[str]=None,
content_type: Optional[str]=None,
**kwargs: Any) -> None:
if encoding is None:
if content_type is None:
encoding = 'utf-8'
content_type = 'text/plain; charset=utf-8'
else:
mimetype = parse_mimetype(content_type)
encoding = mimetype.parameters.get('charset', 'utf-8')
else:
if content_type is None:
content_type = 'text/plain; charset=%s' % encoding
super().__init__(
value,
content_type=content_type,
encoding=encoding,
*args,
**kwargs,
)
@property
def size(self) -> Optional[int]:
try:
return os.fstat(self._value.fileno()).st_size - self._value.tell()
except OSError:
return None
async def write(self, writer: AbstractStreamWriter) -> None:
loop = asyncio.get_event_loop()
try:
chunk = await loop.run_in_executor(
None, self._value.read, DEFAULT_LIMIT
)
while chunk:
await writer.write(chunk.encode(self._encoding))
chunk = await loop.run_in_executor(
None, self._value.read, DEFAULT_LIMIT
)
finally:
await loop.run_in_executor(None, self._value.close)
class BytesIOPayload(IOBasePayload):
@property
def size(self) -> int:
position = self._value.tell()
end = self._value.seek(0, os.SEEK_END)
self._value.seek(position)
return end - position
class BufferedReaderPayload(IOBasePayload):
@property
def size(self) -> Optional[int]:
try:
return os.fstat(self._value.fileno()).st_size - self._value.tell()
except OSError:
# data.fileno() is not supported, e.g.
# io.BufferedReader(io.BytesIO(b'data'))
return None
class JsonPayload(BytesPayload):
def __init__(self,
value: Any,
encoding: str='utf-8',
content_type: str='application/json',
dumps: JSONEncoder=json.dumps,
*args: Any,
**kwargs: Any) -> None:
super().__init__(
dumps(value).encode(encoding),
content_type=content_type, encoding=encoding, *args, **kwargs)
if TYPE_CHECKING: # pragma: no cover
from typing import AsyncIterator, AsyncIterable
_AsyncIterator = AsyncIterator[bytes]
_AsyncIterable = AsyncIterable[bytes]
else:
from collections.abc import AsyncIterable, AsyncIterator
_AsyncIterator = AsyncIterator
_AsyncIterable = AsyncIterable
class AsyncIterablePayload(Payload):
_iter = None # type: Optional[_AsyncIterator]
def __init__(self,
value: _AsyncIterable,
*args: Any,
**kwargs: Any) -> None:
if not isinstance(value, AsyncIterable):
raise TypeError("value argument must support the "
"collections.abc.AsyncIterable interface, "
"got {!r}".format(type(value)))
if 'content_type' not in kwargs:
kwargs['content_type'] = 'application/octet-stream'
super().__init__(value, *args, **kwargs)
self._iter = value.__aiter__()
async def write(self, writer: AbstractStreamWriter) -> None:
if self._iter:
try:
# the "iter is not None" check prevents rare cases
# where the iterable is consumed twice
while True:
chunk = await self._iter.__anext__()
await writer.write(chunk)
except StopAsyncIteration:
self._iter = None
class StreamReaderPayload(AsyncIterablePayload):
def __init__(self, value: StreamReader, *args: Any, **kwargs: Any) -> None:
super().__init__(value.iter_any(), *args, **kwargs)
PAYLOAD_REGISTRY = PayloadRegistry()
PAYLOAD_REGISTRY.register(BytesPayload, (bytes, bytearray, memoryview))
PAYLOAD_REGISTRY.register(StringPayload, str)
PAYLOAD_REGISTRY.register(StringIOPayload, io.StringIO)
PAYLOAD_REGISTRY.register(TextIOPayload, io.TextIOBase)
PAYLOAD_REGISTRY.register(BytesIOPayload, io.BytesIO)
PAYLOAD_REGISTRY.register(
BufferedReaderPayload, (io.BufferedReader, io.BufferedRandom))
PAYLOAD_REGISTRY.register(IOBasePayload, io.IOBase)
PAYLOAD_REGISTRY.register(StreamReaderPayload, StreamReader)
# try_last gives more specialized async iterables, such as
# multipart.BodyPartReaderPayload, a chance to override the default
PAYLOAD_REGISTRY.register(AsyncIterablePayload, AsyncIterable,
order=Order.try_last)
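# Illustrative usage sketch (an assumption, not part of the original file):
# the registry maps a raw value to the most specific Payload subclass.
_bytes_payload = PAYLOAD_REGISTRY.get(b'raw body')
assert isinstance(_bytes_payload, BytesPayload)
assert _bytes_payload.content_type == 'application/octet-stream'

_json_payload = JsonPayload({'status': 'ok'})  # dumps + utf-8 encode
assert _json_payload.content_type == 'application/json'
assert _json_payload.size == len(b'{"status": "ok"}')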

@@ -1,74 +0,0 @@
""" Payload implemenation for coroutines as data provider.
As a simple case, you can upload data from file::
@aiohttp.streamer
async def file_sender(writer, file_name=None):
with open(file_name, 'rb') as f:
chunk = f.read(2**16)
while chunk:
await writer.write(chunk)
chunk = f.read(2**16)
Then you can use `file_sender` like this:
async with session.post('http://httpbin.org/post',
data=file_sender(file_name='huge_file')) as resp:
print(await resp.text())
.. note:: The coroutine must accept `writer` as its first argument.
"""
import asyncio
import warnings
from typing import Any, Awaitable, Callable, Dict, Tuple
from .abc import AbstractStreamWriter
from .payload import Payload, payload_type
__all__ = ('streamer',)
class _stream_wrapper:
def __init__(self,
coro: Callable[..., Awaitable[None]],
args: Tuple[Any, ...],
kwargs: Dict[str, Any]) -> None:
self.coro = asyncio.coroutine(coro)
self.args = args
self.kwargs = kwargs
async def __call__(self, writer: AbstractStreamWriter) -> None:
await self.coro(writer, *self.args, **self.kwargs)
class streamer:
def __init__(self, coro: Callable[..., Awaitable[None]]) -> None:
warnings.warn("@streamer is deprecated, use async generators instead",
DeprecationWarning,
stacklevel=2)
self.coro = coro
def __call__(self, *args: Any, **kwargs: Any) -> _stream_wrapper:
return _stream_wrapper(self.coro, args, kwargs)
@payload_type(_stream_wrapper)
class StreamWrapperPayload(Payload):
async def write(self, writer: AbstractStreamWriter) -> None:
await self._value(writer)
@payload_type(streamer)
class StreamPayload(StreamWrapperPayload):
def __init__(self, value: Any, *args: Any, **kwargs: Any) -> None:
super().__init__(value(), *args, **kwargs)
async def write(self, writer: AbstractStreamWriter) -> None:
await self._value(writer)

@@ -0,0 +1,916 @@
"""Http related parsers and protocol."""
import collections
import functools
import http.server
import re
import string
import sys
import zlib
from abc import ABC, abstractmethod
from wsgiref.handlers import format_date_time
from multidict import CIMultiDict, istr
import aiohttp
from . import errors, hdrs
from .helpers import reify
from .log import internal_logger
__all__ = ('HttpMessage', 'Request', 'Response',
'HttpVersion', 'HttpVersion10', 'HttpVersion11',
'RawRequestMessage', 'RawResponseMessage',
'HttpPrefixParser', 'HttpRequestParser', 'HttpResponseParser',
'HttpPayloadParser')
ASCIISET = set(string.printable)
METHRE = re.compile('[A-Z0-9$-_.]+')
VERSRE = re.compile(r'HTTP/(\d+)\.(\d+)')
HDRRE = re.compile(b'[\x00-\x1F\x7F()<>@,;:\[\]={} \t\\\\\"]')
EOF_MARKER = object()
EOL_MARKER = object()
STATUS_LINE_READY = object()
RESPONSES = http.server.BaseHTTPRequestHandler.responses
HttpVersion = collections.namedtuple(
'HttpVersion', ['major', 'minor'])
HttpVersion10 = HttpVersion(1, 0)
HttpVersion11 = HttpVersion(1, 1)
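# Because HttpVersion is a plain namedtuple, versions compare
# lexicographically, which the `version <= HttpVersion10` checks
# further down rely on (illustrative):
assert HttpVersion(0, 9) < HttpVersion10 < HttpVersion11
assert HttpVersion11 >= HttpVersion(1, 1)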
RawStatusLineMessage = collections.namedtuple(
'RawStatusLineMessage', ['method', 'path', 'version'])
RawRequestMessage = collections.namedtuple(
'RawRequestMessage',
['method', 'path', 'version', 'headers', 'raw_headers',
'should_close', 'compression'])
RawResponseMessage = collections.namedtuple(
'RawResponseMessage',
['version', 'code', 'reason', 'headers', 'raw_headers',
'should_close', 'compression'])
class HttpParser:
def __init__(self, max_line_size=8190, max_headers=32768,
max_field_size=8190):
self.max_line_size = max_line_size
self.max_headers = max_headers
self.max_field_size = max_field_size
def parse_headers(self, lines):
"""Parses RFC 5322 headers from a stream.
Line continuations are supported. Returns list of header name
and value pairs. Header name is in upper case.
"""
close_conn = None
encoding = None
headers = CIMultiDict()
raw_headers = []
lines_idx = 1
line = lines[1]
while line:
header_length = len(line)
# Parse initial header name : value pair.
try:
bname, bvalue = line.split(b':', 1)
except ValueError:
raise errors.InvalidHeader(line) from None
bname = bname.strip(b' \t').upper()
if HDRRE.search(bname):
raise errors.InvalidHeader(bname)
# next line
lines_idx += 1
line = lines[lines_idx]
# consume continuation lines
continuation = line and line[0] in (32, 9) # (' ', '\t')
if continuation:
bvalue = [bvalue]
while continuation:
header_length += len(line)
if header_length > self.max_field_size:
raise errors.LineTooLong(
'limit request headers fields size')
bvalue.append(line)
# next line
lines_idx += 1
line = lines[lines_idx]
continuation = line[0] in (32, 9) # (' ', '\t')
bvalue = b'\r\n'.join(bvalue)
else:
if header_length > self.max_field_size:
raise errors.LineTooLong(
'limit request headers fields size')
bvalue = bvalue.strip()
name = istr(bname.decode('utf-8', 'surrogateescape'))
value = bvalue.decode('utf-8', 'surrogateescape')
# keep-alive and encoding
if name == hdrs.CONNECTION:
v = value.lower()
if v == 'close':
close_conn = True
elif v == 'keep-alive':
close_conn = False
elif name == hdrs.CONTENT_ENCODING:
enc = value.lower()
if enc in ('gzip', 'deflate'):
encoding = enc
headers.add(name, value)
raw_headers.append((bname, bvalue))
return headers, raw_headers, close_conn, encoding
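# Illustrative call (an assumed example, following the code above): given
#   lines = [b'GET / HTTP/1.1', b'Host: example.com',
#            b'Connection: close', b'']
# parse_headers(lines) returns (headers, raw_headers, True, None):
# close_conn is True because of 'Connection: close', and the absence of
# a Content-Encoding header leaves encoding as None.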
class HttpPrefixParser:
"""Waits for 'HTTP' prefix (non destructive)"""
def __init__(self, allowed_methods=()):
self.allowed_methods = [m.upper() for m in allowed_methods]
def __call__(self, out, buf):
raw_data = yield from buf.waituntil(b' ', 12)
method = raw_data.decode('ascii', 'surrogateescape').strip()
# method
method = method.upper()
if not METHRE.match(method):
raise errors.BadStatusLine(method)
# allowed method
if self.allowed_methods and method not in self.allowed_methods:
raise errors.HttpMethodNotAllowed(message=method)
out.feed_data(method, len(method))
out.feed_eof()
class HttpRequestParser(HttpParser):
"""Read request status line. Exception errors.BadStatusLine
could be raised in case of any errors in status line.
Returns RawRequestMessage.
"""
def __call__(self, out, buf):
# read HTTP message (request line + headers)
try:
raw_data = yield from buf.readuntil(
b'\r\n\r\n', self.max_headers)
except errors.LineLimitExceededParserError as exc:
raise errors.LineTooLong(exc.limit) from None
lines = raw_data.split(b'\r\n')
# request line
line = lines[0].decode('utf-8', 'surrogateescape')
try:
method, path, version = line.split(None, 2)
except ValueError:
raise errors.BadStatusLine(line) from None
# method
method = method.upper()
if not METHRE.match(method):
raise errors.BadStatusLine(method)
# version
try:
if version.startswith('HTTP/'):
n1, n2 = version[5:].split('.', 1)
version = HttpVersion(int(n1), int(n2))
else:
raise errors.BadStatusLine(version)
except Exception:
raise errors.BadStatusLine(version)
# read headers
headers, raw_headers, close, compression = self.parse_headers(lines)
if close is None: # then the headers weren't set in the request
if version <= HttpVersion10: # HTTP/1.0 closes by default unless asked to keep alive
close = True
else: # HTTP/1.1 keeps alive by default unless asked to close
close = False
out.feed_data(
RawRequestMessage(
method, path, version, headers, raw_headers,
close, compression),
len(raw_data))
out.feed_eof()
class HttpResponseParser(HttpParser):
"""Read response status line and headers.
BadStatusLine could be raised in case of any errors in status line.
Returns RawResponseMessage"""
def __call__(self, out, buf):
# read HTTP message (response line + headers)
try:
raw_data = yield from buf.readuntil(
b'\r\n\r\n', self.max_line_size + self.max_headers)
except errors.LineLimitExceededParserError as exc:
raise errors.LineTooLong(exc.limit) from None
lines = raw_data.split(b'\r\n')
line = lines[0].decode('utf-8', 'surrogateescape')
try:
version, status = line.split(None, 1)
except ValueError:
raise errors.BadStatusLine(line) from None
else:
try:
status, reason = status.split(None, 1)
except ValueError:
reason = ''
# version
match = VERSRE.match(version)
if match is None:
raise errors.BadStatusLine(line)
version = HttpVersion(int(match.group(1)), int(match.group(2)))
# The status code is a three-digit number
try:
status = int(status)
except ValueError:
raise errors.BadStatusLine(line) from None
if status < 100 or status > 999:
raise errors.BadStatusLine(line)
# read headers
headers, raw_headers, close, compression = self.parse_headers(lines)
if close is None:
close = version <= HttpVersion10
out.feed_data(
RawResponseMessage(
version, status, reason.strip(),
headers, raw_headers, close, compression),
len(raw_data))
out.feed_eof()
class HttpPayloadParser:
def __init__(self, message, length=None, compression=True,
readall=False, response_with_body=True):
self.message = message
self.length = length
self.compression = compression
self.readall = readall
self.response_with_body = response_with_body
def __call__(self, out, buf):
# payload params
length = self.message.headers.get(hdrs.CONTENT_LENGTH, self.length)
if hdrs.SEC_WEBSOCKET_KEY1 in self.message.headers:
length = 8
# payload decompression wrapper
if (self.response_with_body and
self.compression and self.message.compression):
out = DeflateBuffer(out, self.message.compression)
# payload parser
if not self.response_with_body:
# don't parse payload if it's not expected to be received
pass
elif 'chunked' in self.message.headers.get(
hdrs.TRANSFER_ENCODING, ''):
yield from self.parse_chunked_payload(out, buf)
elif length is not None:
try:
length = int(length)
except ValueError:
raise errors.InvalidHeader(hdrs.CONTENT_LENGTH) from None
if length < 0:
raise errors.InvalidHeader(hdrs.CONTENT_LENGTH)
elif length > 0:
yield from self.parse_length_payload(out, buf, length)
else:
if self.readall and getattr(self.message, 'code', 0) != 204:
yield from self.parse_eof_payload(out, buf)
elif getattr(self.message, 'method', None) in ('PUT', 'POST'):
internal_logger.warning( # pragma: no cover
'Content-Length or Transfer-Encoding header is required')
out.feed_eof()
def parse_chunked_payload(self, out, buf):
"""Chunked transfer encoding parser."""
while True:
# read next chunk size
line = yield from buf.readuntil(b'\r\n', 8192)
i = line.find(b';')
if i >= 0:
line = line[:i] # strip chunk-extensions
else:
line = line.strip()
try:
size = int(line, 16)
except ValueError:
raise errors.TransferEncodingError(line) from None
if size == 0: # eof marker
break
# read chunk and feed buffer
while size:
chunk = yield from buf.readsome(size)
out.feed_data(chunk, len(chunk))
size = size - len(chunk)
# toss the CRLF at the end of the chunk
yield from buf.skip(2)
# read and discard trailer up to the CRLF terminator
yield from buf.skipuntil(b'\r\n')
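# Illustrative wire format (an assumed example) consumed by the loop above:
#   b'4\r\nWiki\r\n5\r\npedia\r\n0\r\n\r\n'
# feeds b'Wiki' and then b'pedia' downstream; the zero-size chunk is the
# eof marker, and anything before the final CRLF is a discarded trailer.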
def parse_length_payload(self, out, buf, length=0):
"""Read specified amount of bytes."""
required = length
while required:
chunk = yield from buf.readsome(required)
out.feed_data(chunk, len(chunk))
required -= len(chunk)
def parse_eof_payload(self, out, buf):
"""Read all bytes until eof."""
try:
while True:
chunk = yield from buf.readsome()
out.feed_data(chunk, len(chunk))
except aiohttp.EofStream:
pass
class DeflateBuffer:
"""DeflateStream decompress stream and feed data into specified stream."""
def __init__(self, out, encoding):
self.out = out
zlib_mode = (16 + zlib.MAX_WBITS
if encoding == 'gzip' else -zlib.MAX_WBITS)
self.zlib = zlib.decompressobj(wbits=zlib_mode)
def feed_data(self, chunk, size):
try:
chunk = self.zlib.decompress(chunk)
except Exception:
raise errors.ContentEncodingError('deflate')
if chunk:
self.out.feed_data(chunk, len(chunk))
def feed_eof(self):
chunk = self.zlib.flush()
self.out.feed_data(chunk, len(chunk))
if not self.zlib.eof:
raise errors.ContentEncodingError('deflate')
self.out.feed_eof()
def wrap_payload_filter(func):
"""Wraps payload filter and piped filters.
Filter is a generator that accepts arbitrary chunks of data,
modify data and emit new stream of data.
For example we have stream of chunks: ['1', '2', '3', '4', '5'],
we can apply chunking filter to this stream:
['1', '2', '3', '4', '5']
|
response.add_chunking_filter(2)
|
['12', '34', '5']
It is possible to use different filters at the same time.
For a example to compress incoming stream with 'deflate' encoding
and then split data and emit chunks of 8192 bytes size chunks:
>>> response.add_compression_filter('deflate')
>>> response.add_chunking_filter(8192)
Filters do not alter transfer encoding.
Filter can receive types types of data, bytes object or EOF_MARKER.
1. If filter receives bytes object, it should process data
and yield processed data then yield EOL_MARKER object.
2. If Filter received EOF_MARKER, it should yield remaining
data (buffered) and then yield EOF_MARKER.
"""
@functools.wraps(func)
def wrapper(self, *args, **kw):
new_filter = func(self, *args, **kw)
filter = self.filter
if filter is not None:
next(new_filter)
self.filter = filter_pipe(filter, new_filter)
else:
self.filter = new_filter
next(self.filter)
return wrapper
def filter_pipe(filter, filter2, *,
EOF_MARKER=EOF_MARKER, EOL_MARKER=EOL_MARKER):
"""Creates pipe between two filters.
filter_pipe() feeds first filter with incoming data and then
send yielded from first filter data into filter2, results of
filter2 are being emitted.
1. If filter_pipe receives bytes object, it sends it to the first filter.
2. Reads yielded values from the first filter until it receives
EOF_MARKER or EOL_MARKER.
3. Each of this values is being send to second filter.
4. Reads yielded values from second filter until it receives EOF_MARKER
or EOL_MARKER. Each of this values yields to writer.
"""
chunk = yield
while True:
eof = chunk is EOF_MARKER
chunk = filter.send(chunk)
while chunk is not EOL_MARKER:
chunk = filter2.send(chunk)
while chunk not in (EOF_MARKER, EOL_MARKER):
yield chunk
chunk = next(filter2)
if chunk is not EOF_MARKER:
if eof:
chunk = EOF_MARKER
else:
chunk = next(filter)
else:
break
chunk = yield EOL_MARKER
class HttpMessage(ABC):
"""HttpMessage allows to write headers and payload to a stream.
For example, lets say we want to read file then compress it with deflate
compression and then send it with chunked transfer encoding, code may look
like this:
>>> response = aiohttp.Response(transport, 200)
We have to use deflate compression first:
>>> response.add_compression_filter('deflate')
Then we want to split output stream into chunks of 1024 bytes size:
>>> response.add_chunking_filter(1024)
We can add headers to response with add_headers() method. add_headers()
does not send data to transport, send_headers() sends request/response
line and then sends headers:
>>> response.add_headers(
... ('Content-Disposition', 'attachment; filename="..."'))
>>> response.send_headers()
Now we can use chunked writer to write stream to a network stream.
First call to write() method sends response status line and headers,
add_header() and add_headers() method unavailable at this stage:
>>> with open('...', 'rb') as f:
... chunk = fp.read(8192)
... while chunk:
... response.write(chunk)
... chunk = fp.read(8192)
>>> response.write_eof()
"""
writer = None
# 'filter' is used for altering write() behaviour:
# add_compression_filter adds deflate/gzip compression and
# add_chunking_filter splits incoming data into chunks.
filter = None
HOP_HEADERS = None # Must be set by subclass.
SERVER_SOFTWARE = 'Python/{0[0]}.{0[1]} aiohttp/{1}'.format(
sys.version_info, aiohttp.__version__)
upgrade = False # Connection: UPGRADE
websocket = False # Upgrade: WEBSOCKET
has_chunked_hdr = False # Transfer-encoding: chunked
# subclass can enable auto sending headers with write() call,
# this is useful for wsgi's start_response implementation.
_send_headers = False
def __init__(self, transport, version, close):
self.transport = transport
self._version = version
self.closing = close
self.keepalive = None
self.chunked = False
self.length = None
self.headers = CIMultiDict()
self.headers_sent = False
self.output_length = 0
self.headers_length = 0
self._output_size = 0
@property
@abstractmethod
def status_line(self):
return b''
@abstractmethod
def autochunked(self):
return False
@property
def version(self):
return self._version
@property
def body_length(self):
return self.output_length - self.headers_length
def force_close(self):
self.closing = True
self.keepalive = False
def enable_chunked_encoding(self):
self.chunked = True
def keep_alive(self):
if self.keepalive is None:
if self.version < HttpVersion10:
# keep alive not supported at all
return False
if self.version == HttpVersion10:
if self.headers.get(hdrs.CONNECTION) == 'keep-alive':
return True
else: # no headers means we close for Http 1.0
return False
else:
return not self.closing
else:
return self.keepalive
def is_headers_sent(self):
return self.headers_sent
def add_header(self, name, value):
"""Analyze headers. Calculate content length,
removes hop headers, etc."""
assert not self.headers_sent, 'headers have been sent already'
assert isinstance(name, str), \
'Header name should be a string, got {!r}'.format(name)
assert set(name).issubset(ASCIISET), \
'Header name should contain ASCII chars, got {!r}'.format(name)
assert isinstance(value, str), \
'Header {!r} should have string value, got {!r}'.format(
name, value)
name = istr(name)
value = value.strip()
if name == hdrs.CONTENT_LENGTH:
self.length = int(value)
if name == hdrs.TRANSFER_ENCODING:
self.has_chunked_hdr = value.lower().strip() == 'chunked'
if name == hdrs.CONNECTION:
val = value.lower()
# handle websocket
if 'upgrade' in val:
self.upgrade = True
# connection keep-alive
elif 'close' in val:
self.keepalive = False
elif 'keep-alive' in val:
self.keepalive = True
elif name == hdrs.UPGRADE:
if 'websocket' in value.lower():
self.websocket = True
self.headers[name] = value
elif name not in self.HOP_HEADERS:
# ignore hop-by-hop headers
self.headers.add(name, value)
def add_headers(self, *headers):
"""Adds headers to a HTTP message."""
for name, value in headers:
self.add_header(name, value)
def send_headers(self, _sep=': ', _end='\r\n'):
"""Writes headers to a stream. Constructs payload writer."""
# Chunked responses are only for HTTP/1.1 clients or newer,
# and only when no Content-Length header is set.
# Do not use chunked responses when the response is guaranteed to
# not have a response body (304, 204).
assert not self.headers_sent, 'headers have been sent already'
self.headers_sent = True
if self.chunked or self.autochunked():
self.writer = self._write_chunked_payload()
self.headers[hdrs.TRANSFER_ENCODING] = 'chunked'
elif self.length is not None:
self.writer = self._write_length_payload(self.length)
else:
self.writer = self._write_eof_payload()
next(self.writer)
self._add_default_headers()
# status + headers
headers = self.status_line + ''.join(
[k + _sep + v + _end for k, v in self.headers.items()])
headers = headers.encode('utf-8') + b'\r\n'
self.output_length += len(headers)
self.headers_length = len(headers)
self.transport.write(headers)
def _add_default_headers(self):
# set the connection header
connection = None
if self.upgrade:
connection = 'upgrade'
# keep-alive is in effect: explicit keepalive flag, or not closing by default
elif (not self.closing if self.keepalive is None else self.keepalive):
if self.version == HttpVersion10:
connection = 'keep-alive'
else:
if self.version == HttpVersion11:
connection = 'close'
if connection is not None:
self.headers[hdrs.CONNECTION] = connection
def write(self, chunk, *,
drain=False, EOF_MARKER=EOF_MARKER, EOL_MARKER=EOL_MARKER):
"""Writes chunk of data to a stream by using different writers.
writer uses filter to modify chunk of data.
write_eof() indicates end of stream.
writer can't be used after write_eof() method being called.
write() return drain future.
"""
assert (isinstance(chunk, (bytes, bytearray)) or
chunk is EOF_MARKER), chunk
size = self.output_length
if self._send_headers and not self.headers_sent:
self.send_headers()
assert self.writer is not None, 'send_headers() is not called.'
if self.filter:
chunk = self.filter.send(chunk)
while chunk not in (EOF_MARKER, EOL_MARKER):
if chunk:
self.writer.send(chunk)
chunk = next(self.filter)
else:
if chunk is not EOF_MARKER:
self.writer.send(chunk)
self._output_size += self.output_length - size
if self._output_size > 64 * 1024:
if drain:
self._output_size = 0
return self.transport.drain()
return ()
def write_eof(self):
self.write(EOF_MARKER)
try:
self.writer.throw(aiohttp.EofStream())
except StopIteration:
pass
return self.transport.drain()
def _write_chunked_payload(self):
"""Write data in chunked transfer encoding."""
while True:
try:
chunk = yield
except aiohttp.EofStream:
self.transport.write(b'0\r\n\r\n')
self.output_length += 5
break
chunk = bytes(chunk)
chunk_len = '{:x}\r\n'.format(len(chunk)).encode('ascii')
self.transport.write(chunk_len + chunk + b'\r\n')
self.output_length += len(chunk_len) + len(chunk) + 2
def _write_length_payload(self, length):
"""Write specified number of bytes to a stream."""
while True:
try:
chunk = yield
except aiohttp.EofStream:
break
if length:
chunk_len = len(chunk)
if length >= chunk_len:
self.transport.write(chunk)
self.output_length += chunk_len
length = length - chunk_len
else:
self.transport.write(chunk[:length])
self.output_length += length
length = 0
def _write_eof_payload(self):
while True:
try:
chunk = yield
except aiohttp.EofStream:
break
self.transport.write(chunk)
self.output_length += len(chunk)
@wrap_payload_filter
def add_chunking_filter(self, chunk_size=16*1024, *,
EOF_MARKER=EOF_MARKER, EOL_MARKER=EOL_MARKER):
"""Split incoming stream into chunks."""
buf = bytearray()
chunk = yield
while True:
if chunk is EOF_MARKER:
if buf:
yield buf
yield EOF_MARKER
else:
buf.extend(chunk)
while len(buf) >= chunk_size:
chunk = bytes(buf[:chunk_size])
del buf[:chunk_size]
yield chunk
chunk = yield EOL_MARKER
@wrap_payload_filter
def add_compression_filter(self, encoding='deflate', *,
EOF_MARKER=EOF_MARKER, EOL_MARKER=EOL_MARKER):
"""Compress incoming stream with deflate or gzip encoding."""
zlib_mode = (16 + zlib.MAX_WBITS
if encoding == 'gzip' else -zlib.MAX_WBITS)
zcomp = zlib.compressobj(wbits=zlib_mode)
chunk = yield
while True:
if chunk is EOF_MARKER:
yield zcomp.flush()
chunk = yield EOF_MARKER
else:
yield zcomp.compress(chunk)
chunk = yield EOL_MARKER
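# --- Editor's illustrative sketch (not part of the original diff): the wbits
# convention used by add_compression_filter above. 16 + MAX_WBITS selects a
# gzip wrapper, -MAX_WBITS a raw deflate stream. Pure stdlib, runnable.
import zlib

data = b'x' * 1000
for mode in (16 + zlib.MAX_WBITS, -zlib.MAX_WBITS):
    zcomp = zlib.compressobj(wbits=mode)
    body = zcomp.compress(data) + zcomp.flush()
    # decompressing with the matching wbits value round-trips the payload
    assert zlib.decompress(body, mode) == data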
class Response(HttpMessage):
"""Create HTTP response message.
Transport is a socket stream transport. status is a response status code
and has to be an integer value. http_version is a tuple representing the
HTTP version: (1, 0) stands for HTTP/1.0 and (1, 1) for HTTP/1.1.
"""
HOP_HEADERS = ()
@staticmethod
def calc_reason(status, *, _RESPONSES=RESPONSES):
record = _RESPONSES.get(status)
if record is not None:
reason = record[0]
else:
reason = str(status)
return reason
def __init__(self, transport, status,
http_version=HttpVersion11, close=False, reason=None):
super().__init__(transport, http_version, close)
self._status = status
if reason is None:
reason = self.calc_reason(status)
self._reason = reason
@property
def status(self):
return self._status
@property
def reason(self):
return self._reason
@reify
def status_line(self):
version = self.version
return 'HTTP/{}.{} {} {}\r\n'.format(
version[0], version[1], self.status, self.reason)
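# --- Editor's illustrative note (not part of the original diff): for an
# HTTP/1.1 200 response, the reified status_line above evaluates to:
assert 'HTTP/{}.{} {} {}\r\n'.format(1, 1, 200, 'OK') == 'HTTP/1.1 200 OK\r\n'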
def autochunked(self):
return (self.length is None and
self.version >= HttpVersion11)
def _add_default_headers(self):
super()._add_default_headers()
if hdrs.DATE not in self.headers:
# format_date_time(None) is quite expensive
self.headers.setdefault(hdrs.DATE, format_date_time(None))
self.headers.setdefault(hdrs.SERVER, self.SERVER_SOFTWARE)
class Request(HttpMessage):
HOP_HEADERS = ()
def __init__(self, transport, method, path,
http_version=HttpVersion11, close=False):
# set the default for HTTP 0.9 to be different
# will only be overwritten with keep-alive header
if http_version < HttpVersion10:
close = True
super().__init__(transport, http_version, close)
self._method = method
self._path = path
@property
def method(self):
return self._method
@property
def path(self):
return self._path
@reify
def status_line(self):
return '{0} {1} HTTP/{2[0]}.{2[1]}\r\n'.format(
self.method, self.path, self.version)
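# --- Editor's illustrative note (not part of the original diff): for a simple
# GET, the request line produced by the format string above is:
assert '{0} {1} HTTP/{2[0]}.{2[1]}\r\n'.format(
    'GET', '/index.html', (1, 1)) == 'GET /index.html HTTP/1.1\r\n'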
def autochunked(self):
return (self.length is None and
self.version >= HttpVersion11 and
self.status not in (304, 204))
@ -1,142 +1,17 @@
import asyncio
import contextlib
import warnings
from collections.abc import Callable
import pytest
from aiohttp.helpers import PY_37, isasyncgenfunction
from aiohttp.web import Application
from .test_utils import (
BaseTestServer,
RawTestServer,
TestClient,
TestServer,
loop_context,
setup_test_loop,
teardown_test_loop,
)
from .test_utils import unused_port as _unused_port
try:
import uvloop
except ImportError: # pragma: no cover
uvloop = None
try:
import tokio
except ImportError: # pragma: no cover
tokio = None
def pytest_addoption(parser): # type: ignore
parser.addoption(
'--aiohttp-fast', action='store_true', default=False,
help='run tests faster by disabling extra checks')
parser.addoption(
'--aiohttp-loop', action='store', default='pyloop',
help='run tests with specific loop: pyloop, uvloop, tokio or all')
parser.addoption(
'--aiohttp-enable-loop-debug', action='store_true', default=False,
help='enable event loop debug mode')
def pytest_fixture_setup(fixturedef): # type: ignore
"""
Allow fixtures to be coroutines. Run coroutine fixtures in an event loop.
"""
func = fixturedef.func
if isasyncgenfunction(func):
# async generator fixture
is_async_gen = True
elif asyncio.iscoroutinefunction(func):
# regular async fixture
is_async_gen = False
else:
# not an async fixture, nothing to do
return
strip_request = False
if 'request' not in fixturedef.argnames:
fixturedef.argnames += ('request',)
strip_request = True
def wrapper(*args, **kwargs): # type: ignore
request = kwargs['request']
if strip_request:
del kwargs['request']
# if neither the fixture nor the test use the 'loop' fixture,
# 'getfixturevalue' will fail because the test is not parameterized
# (this can be removed someday if 'loop' is no longer parameterized)
if 'loop' not in request.fixturenames:
raise Exception(
"Asynchronous fixtures must depend on the 'loop' fixture or "
"be used in tests depending from it."
)
_loop = request.getfixturevalue('loop')
if is_async_gen:
# for async generators, we need to advance the generator once,
# then advance it again in a finalizer
gen = func(*args, **kwargs)
def finalizer(): # type: ignore
try:
return _loop.run_until_complete(gen.__anext__())
except StopAsyncIteration: # NOQA
pass
request.addfinalizer(finalizer)
return _loop.run_until_complete(gen.__anext__())
else:
return _loop.run_until_complete(func(*args, **kwargs))
fixturedef.func = wrapper
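# --- Editor's illustrative sketch (not part of the original diff): what the
# wrapper above enables in a test module. Async generator fixtures must depend
# on the parametrized 'loop' fixture, as enforced by the check in wrapper().
# The fixture/test names below are hypothetical.
import asyncio
import pytest

@pytest.fixture
async def resource(loop):
    await asyncio.sleep(0)  # setup runs on the wrapper's first gen.__anext__()
    yield 'ready'
    await asyncio.sleep(0)  # teardown runs in the finalizer's __anext__() call

async def test_resource(resource):
    assert resource == 'ready'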
@pytest.fixture
def fast(request): # type: ignore
"""--fast config option"""
return request.config.getoption('--aiohttp-fast')
@pytest.fixture
def loop_debug(request): # type: ignore
"""--enable-loop-debug config option"""
return request.config.getoption('--aiohttp-enable-loop-debug')
from .test_utils import (TestClient, TestServer, loop_context, setup_test_loop,
teardown_test_loop)
@contextlib.contextmanager
def _runtime_warning_context(): # type: ignore
"""
Context manager which checks for RuntimeWarnings, specifically to
avoid "coroutine 'X' was never awaited" warnings being missed.
If RuntimeWarnings occur in the context a RuntimeError is raised.
"""
with warnings.catch_warnings(record=True) as _warnings:
yield
rw = ['{w.filename}:{w.lineno}:{w.message}'.format(w=w)
for w in _warnings # type: ignore
if w.category == RuntimeWarning]
if rw:
raise RuntimeError('{} Runtime Warning{},\n{}'.format(
len(rw),
'' if len(rw) == 1 else 's',
'\n'.join(rw)
))
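# --- Editor's illustrative sketch (not part of the original diff): the class
# of bug _runtime_warning_context catches. Creating a coroutine and never
# awaiting it emits a "was never awaited" RuntimeWarning on deallocation,
# which the context manager above escalates to a RuntimeError.
import asyncio
import gc
import warnings

async def forgotten():
    return 42

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter('always')
    forgotten()   # coroutine object created but never awaited
    gc.collect()  # force deallocation so the warning fires deterministically

assert any(w.category is RuntimeWarning for w in caught)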
@contextlib.contextmanager
def _passthrough_loop_context(loop, fast=False): # type: ignore
"""
Sets up and tears down a loop unless one is passed in via the loop
argument, in which case it is passed straight through.
"""
def _passthrough_loop_context(loop):
if loop:
# loop already exists, pass it straight through
yield loop
@ -144,10 +19,10 @@ def _passthrough_loop_context(loop, fast=False): # type: ignore
# this shadows loop_context's standard behavior
loop = setup_test_loop()
yield loop
teardown_test_loop(loop, fast=fast)
teardown_test_loop(loop)
def pytest_pycollect_makeitem(collector, name, obj): # type: ignore
def pytest_pycollect_makeitem(collector, name, obj):
"""
Fix pytest collecting for coroutines.
"""
@ -155,198 +30,84 @@ def pytest_pycollect_makeitem(collector, name, obj): # type: ignore
return list(collector._genfunctions(name, obj))
def pytest_pyfunc_call(pyfuncitem): # type: ignore
def pytest_pyfunc_call(pyfuncitem):
"""
Run coroutines in an event loop instead of a normal function call.
"""
fast = pyfuncitem.config.getoption("--aiohttp-fast")
if asyncio.iscoroutinefunction(pyfuncitem.function):
existing_loop = pyfuncitem.funcargs.get('proactor_loop')\
or pyfuncitem.funcargs.get('loop', None)
with _runtime_warning_context():
with _passthrough_loop_context(existing_loop, fast=fast) as _loop:
testargs = {arg: pyfuncitem.funcargs[arg]
for arg in pyfuncitem._fixtureinfo.argnames}
_loop.run_until_complete(pyfuncitem.obj(**testargs))
existing_loop = pyfuncitem.funcargs.get('loop', None)
with _passthrough_loop_context(existing_loop) as _loop:
testargs = {arg: pyfuncitem.funcargs[arg]
for arg in pyfuncitem._fixtureinfo.argnames}
return True
def pytest_generate_tests(metafunc): # type: ignore
if 'loop_factory' not in metafunc.fixturenames:
return
loops = metafunc.config.option.aiohttp_loop
avail_factories = {'pyloop': asyncio.DefaultEventLoopPolicy}
if uvloop is not None: # pragma: no cover
avail_factories['uvloop'] = uvloop.EventLoopPolicy
if tokio is not None: # pragma: no cover
avail_factories['tokio'] = tokio.EventLoopPolicy
if loops == 'all':
loops = 'pyloop,uvloop?,tokio?'
factories = {} # type: ignore
for name in loops.split(','):
required = not name.endswith('?')
name = name.strip(' ?')
if name not in avail_factories: # pragma: no cover
if required:
raise ValueError(
"Unknown loop '%s', available loops: %s" % (
name, list(factories.keys())))
else:
continue
factories[name] = avail_factories[name]
metafunc.parametrize("loop_factory",
list(factories.values()),
ids=list(factories.keys()))
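# --- Editor's illustrative sketch (not part of the original diff): the
# loop-name parsing performed above. A trailing '?' marks a loop as optional,
# so '--aiohttp-loop=all' tolerates missing uvloop/tokio installs.
# 'parse_loop_spec' is a hypothetical helper mirroring the loop over
# loops.split(',').
def parse_loop_spec(spec, available):
    selected = {}
    for name in spec.split(','):
        required = not name.endswith('?')
        name = name.strip(' ?')
        if name not in available:
            if required:
                raise ValueError("Unknown loop %r" % name)
            continue
        selected[name] = available[name]
    return selected

assert list(parse_loop_spec('pyloop,uvloop?,tokio?', {'pyloop': object})) == ['pyloop']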
@pytest.fixture
def loop(loop_factory, fast, loop_debug): # type: ignore
"""Return an instance of the event loop."""
policy = loop_factory()
asyncio.set_event_loop_policy(policy)
with loop_context(fast=fast) as _loop:
if loop_debug:
_loop.set_debug(True) # pragma: no cover
asyncio.set_event_loop(_loop)
yield _loop
task = _loop.create_task(pyfuncitem.obj(**testargs))
_loop.run_until_complete(task)
return True
@pytest.fixture
def proactor_loop(): # type: ignore
if not PY_37:
policy = asyncio.get_event_loop_policy()
policy._loop_factory = asyncio.ProactorEventLoop # type: ignore
else:
policy = asyncio.WindowsProactorEventLoopPolicy() # type: ignore
asyncio.set_event_loop_policy(policy)
with loop_context(policy.new_event_loop) as _loop:
asyncio.set_event_loop(_loop)
@pytest.yield_fixture
def loop():
with loop_context() as _loop:
yield _loop
@pytest.fixture
def unused_port(aiohttp_unused_port): # type: ignore # pragma: no cover
warnings.warn("Deprecated, use aiohttp_unused_port fixture instead",
DeprecationWarning)
return aiohttp_unused_port
@pytest.fixture
def aiohttp_unused_port(): # type: ignore
"""Return a port that is unused on the current host."""
def unused_port():
return _unused_port
@pytest.fixture
def aiohttp_server(loop): # type: ignore
"""Factory to create a TestServer instance, given an app.
aiohttp_server(app, **kwargs)
"""
@pytest.yield_fixture
def test_server(loop):
servers = []
async def go(app, *, port=None, **kwargs): # type: ignore
server = TestServer(app, port=port)
await server.start_server(loop=loop, **kwargs)
servers.append(server)
return server
yield go
async def finalize(): # type: ignore
while servers:
await servers.pop().close()
loop.run_until_complete(finalize())
@pytest.fixture
def test_server(aiohttp_server): # type: ignore # pragma: no cover
warnings.warn("Deprecated, use aiohttp_server fixture instead",
DeprecationWarning)
return aiohttp_server
@asyncio.coroutine
def go(app, **kwargs):
assert app.loop is loop, \
"Application is attached to other event loop"
@pytest.fixture
def aiohttp_raw_server(loop): # type: ignore
"""Factory to create a RawTestServer instance, given a web handler.
aiohttp_raw_server(handler, **kwargs)
"""
servers = []
async def go(handler, *, port=None, **kwargs): # type: ignore
server = RawTestServer(handler, port=port)
await server.start_server(loop=loop, **kwargs)
server = TestServer(app)
yield from server.start_server(**kwargs)
servers.append(server)
return server
yield go
async def finalize(): # type: ignore
@asyncio.coroutine
def finalize():
while servers:
await servers.pop().close()
yield from servers.pop().close()
loop.run_until_complete(finalize())
@pytest.fixture
def raw_test_server(aiohttp_raw_server): # type: ignore # pragma: no cover
warnings.warn("Deprecated, use aiohttp_raw_server fixture instead",
DeprecationWarning)
return aiohttp_raw_server
@pytest.fixture
def aiohttp_client(loop): # type: ignore
"""Factory to create a TestClient instance.
aiohttp_client(app, **kwargs)
aiohttp_client(server, **kwargs)
aiohttp_client(raw_server, **kwargs)
"""
@pytest.yield_fixture
def test_client(loop):
clients = []
async def go(__param, *args, server_kwargs=None, **kwargs): # type: ignore
if (isinstance(__param, Callable) and # type: ignore
not isinstance(__param, (Application, BaseTestServer))):
__param = __param(loop, *args, **kwargs)
kwargs = {}
else:
assert not args, "args should be empty"
@asyncio.coroutine
def go(__param, *args, **kwargs):
if isinstance(__param, Application):
server_kwargs = server_kwargs or {}
server = TestServer(__param, loop=loop, **server_kwargs)
client = TestClient(server, loop=loop, **kwargs)
elif isinstance(__param, BaseTestServer):
client = TestClient(__param, loop=loop, **kwargs)
assert not args, "args should be empty"
assert not kwargs, "kwargs should be empty"
assert __param.loop is loop, \
"Application is attached to other event loop"
elif isinstance(__param, TestServer):
assert __param.app.loop is loop, \
"TestServer is attached to other event loop"
else:
raise ValueError("Unknown argument type: %r" % type(__param))
__param = __param(loop, *args, **kwargs)
await client.start_server()
client = TestClient(__param)
yield from client.start_server()
clients.append(client)
return client
yield go
async def finalize(): # type: ignore
@asyncio.coroutine
def finalize():
while clients:
await clients.pop().close()
yield from clients.pop().close()
loop.run_until_complete(finalize())
@pytest.fixture
def test_client(aiohttp_client): # type: ignore # pragma: no cover
warnings.warn("Deprecated, use aiohttp_client fixture instead",
DeprecationWarning)
return aiohttp_client
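# --- Editor's illustrative sketch (not part of the original diff): a minimal
# test module written against the aiohttp_client factory fixture defined
# above. The handler, route and assertions are hypothetical.
from aiohttp import web

async def hello(request):
    return web.Response(text='Hello, world')

async def test_hello(aiohttp_client):
    app = web.Application()
    app.router.add_get('/', hello)
    client = await aiohttp_client(app)
    resp = await client.get('/')
    assert resp.status == 200
    assert 'Hello' in await resp.text()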
@ -1,19 +1,16 @@
import asyncio
import socket
from typing import Any, Dict, List, Optional
from .abc import AbstractResolver
from .helpers import get_running_loop
__all__ = ('ThreadedResolver', 'AsyncResolver', 'DefaultResolver')
try:
import aiodns
# aiodns_default = hasattr(aiodns.DNSResolver, 'gethostbyname')
aiodns_default = hasattr(aiodns.DNSResolver, 'gethostbyname')
except ImportError: # pragma: no cover
aiodns = None
aiodns_default = False
aiodns_default = False
class ThreadedResolver(AbstractResolver):
@ -21,12 +18,14 @@ class ThreadedResolver(AbstractResolver):
concurrent.futures.ThreadPoolExecutor.
"""
def __init__(self, loop: Optional[asyncio.AbstractEventLoop]=None) -> None:
self._loop = get_running_loop(loop)
def __init__(self, loop=None):
if loop is None:
loop = asyncio.get_event_loop()
self._loop = loop
async def resolve(self, host: str, port: int=0,
family: int=socket.AF_INET) -> List[Dict[str, Any]]:
infos = await self._loop.getaddrinfo(
@asyncio.coroutine
def resolve(self, host, port=0, family=socket.AF_INET):
infos = yield from self._loop.getaddrinfo(
host, port, type=socket.SOCK_STREAM, family=family)
hosts = []
@ -39,60 +38,51 @@ class ThreadedResolver(AbstractResolver):
return hosts
async def close(self) -> None:
@asyncio.coroutine
def close(self):
pass
class AsyncResolver(AbstractResolver):
"""Use the `aiodns` package to make asynchronous DNS lookups"""
def __init__(self, loop: Optional[asyncio.AbstractEventLoop]=None,
*args: Any, **kwargs: Any) -> None:
def __init__(self, loop=None, *args, **kwargs):
if loop is None:
loop = asyncio.get_event_loop()
if aiodns is None:
raise RuntimeError("Resolver requires aiodns library")
self._loop = get_running_loop(loop)
self._loop = loop
self._resolver = aiodns.DNSResolver(*args, loop=loop, **kwargs)
if not hasattr(self._resolver, 'gethostbyname'):
# aiodns 1.1 is not available, fallback to DNSResolver.query
self.resolve = self._resolve_with_query # type: ignore
async def resolve(self, host: str, port: int=0,
family: int=socket.AF_INET) -> List[Dict[str, Any]]:
try:
resp = await self._resolver.gethostbyname(host, family)
except aiodns.error.DNSError as exc:
msg = exc.args[1] if len(exc.args) >= 1 else "DNS lookup failed"
raise OSError(msg) from exc
self.resolve = self.resolve_with_query
@asyncio.coroutine
def resolve(self, host, port=0, family=socket.AF_INET):
hosts = []
resp = yield from self._resolver.gethostbyname(host, family)
for address in resp.addresses:
hosts.append(
{'hostname': host,
'host': address, 'port': port,
'family': family, 'proto': 0,
'flags': socket.AI_NUMERICHOST})
if not hosts:
raise OSError("DNS lookup failed")
return hosts
async def _resolve_with_query(
self, host: str, port: int=0,
family: int=socket.AF_INET) -> List[Dict[str, Any]]:
@asyncio.coroutine
def resolve_with_query(self, host, port=0, family=socket.AF_INET):
if family == socket.AF_INET6:
qtype = 'AAAA'
else:
qtype = 'A'
try:
resp = await self._resolver.query(host, qtype)
except aiodns.error.DNSError as exc:
msg = exc.args[1] if len(exc.args) >= 1 else "DNS lookup failed"
raise OSError(msg) from exc
hosts = []
resp = yield from self._resolver.query(host, qtype)
for rr in resp:
hosts.append(
{'hostname': host,
@ -100,12 +90,10 @@ class AsyncResolver(AbstractResolver):
'family': family, 'proto': 0,
'flags': socket.AI_NUMERICHOST})
if not hosts:
raise OSError("DNS lookup failed")
return hosts
async def close(self) -> None:
@asyncio.coroutine
def close(self):
return self._resolver.cancel()
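# --- Editor's illustrative sketch (not part of the original diff): driving
# ThreadedResolver by hand, written against the async/await variant shown
# above. resolve() returns getaddrinfo-style dicts with 'host', 'port',
# 'family', 'proto' and 'flags' keys.
import asyncio
from aiohttp.resolver import ThreadedResolver

async def main():
    resolver = ThreadedResolver()
    hosts = await resolver.resolve('localhost', 80)
    print(hosts[0]['host'], hosts[0]['port'])
    await resolver.close()

asyncio.get_event_loop().run_until_complete(main())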
@ -0,0 +1,376 @@
"""simple HTTP server."""
import asyncio
import http.server
import socket
import traceback
import warnings
from contextlib import suppress
from html import escape as html_escape
import aiohttp
from aiohttp import errors, hdrs, helpers, streams
from aiohttp.helpers import Timeout, _get_kwarg, ensure_future
from aiohttp.log import access_logger, server_logger
__all__ = ('ServerHttpProtocol',)
RESPONSES = http.server.BaseHTTPRequestHandler.responses
DEFAULT_ERROR_MESSAGE = """
<html>
<head>
<title>{status} {reason}</title>
</head>
<body>
<h1>{status} {reason}</h1>
{message}
</body>
</html>"""
if hasattr(socket, 'SO_KEEPALIVE'):
def tcp_keepalive(server, transport):
sock = transport.get_extra_info('socket')
sock.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1)
else:
def tcp_keepalive(server, transport): # pragma: no cover
pass
EMPTY_PAYLOAD = streams.EmptyStreamReader()
class ServerHttpProtocol(aiohttp.StreamProtocol):
"""Simple HTTP protocol implementation.
ServerHttpProtocol handles an incoming HTTP request. It reads the request
line, request headers and request payload, then calls the handle_request()
method. By default it always returns a 404 response.
ServerHttpProtocol handles errors in incoming request, like bad
status line, bad headers or incomplete payload. If any error occurs,
connection gets closed.
:param keepalive_timeout: number of seconds before closing
keep-alive connection
:type keepalive_timeout: int or None
:param bool tcp_keepalive: enable TCP keep-alive, default is on
:param int slow_request_timeout: slow request timeout
:param bool debug: enable debug mode
:param logger: custom logger object
:type logger: aiohttp.log.server_logger
:param access_log: custom logging object
:type access_log: aiohttp.log.server_logger
:param str access_log_format: access log format string
:param loop: Optional event loop
:param int max_line_size: Optional maximum header line size
:param int max_field_size: Optional maximum header field size
:param int max_headers: Optional maximum header size
"""
_request_count = 0
_request_handler = None
_reading_request = False
_keepalive = False # keep transport open
def __init__(self, *, loop=None,
keepalive_timeout=75, # NGINX default value is 75 secs
tcp_keepalive=True,
slow_request_timeout=0,
logger=server_logger,
access_log=access_logger,
access_log_format=helpers.AccessLogger.LOG_FORMAT,
debug=False,
max_line_size=8190,
max_headers=32768,
max_field_size=8190,
**kwargs):
# process deprecated params
logger = _get_kwarg(kwargs, 'log', 'logger', logger)
tcp_keepalive = _get_kwarg(kwargs, 'keep_alive_on',
'tcp_keepalive', tcp_keepalive)
keepalive_timeout = _get_kwarg(kwargs, 'keep_alive',
'keepalive_timeout', keepalive_timeout)
slow_request_timeout = _get_kwarg(kwargs, 'timeout',
'slow_request_timeout',
slow_request_timeout)
super().__init__(
loop=loop,
disconnect_error=errors.ClientDisconnectedError, **kwargs)
self._tcp_keepalive = tcp_keepalive
self._keepalive_timeout = keepalive_timeout
self._slow_request_timeout = slow_request_timeout
self._loop = loop if loop is not None else asyncio.get_event_loop()
self._request_prefix = aiohttp.HttpPrefixParser()
self._request_parser = aiohttp.HttpRequestParser(
max_line_size=max_line_size,
max_field_size=max_field_size,
max_headers=max_headers)
self.logger = logger
self.debug = debug
self.access_log = access_log
if access_log:
self.access_logger = helpers.AccessLogger(access_log,
access_log_format)
else:
self.access_logger = None
self._closing = False
@property
def keep_alive_timeout(self):
warnings.warn("Use keepalive_timeout property instead",
DeprecationWarning,
stacklevel=2)
return self._keepalive_timeout
@property
def keepalive_timeout(self):
return self._keepalive_timeout
@asyncio.coroutine
def shutdown(self, timeout=15.0):
"""Worker process is about to exit, we need cleanup everything and
stop accepting requests. It is especially important for keep-alive
connections."""
if self._request_handler is None:
return
self._closing = True
if timeout:
canceller = self._loop.call_later(timeout,
self._request_handler.cancel)
with suppress(asyncio.CancelledError):
yield from self._request_handler
canceller.cancel()
else:
self._request_handler.cancel()
def connection_made(self, transport):
super().connection_made(transport)
self._request_handler = ensure_future(self.start(), loop=self._loop)
if self._tcp_keepalive:
tcp_keepalive(self, transport)
def connection_lost(self, exc):
super().connection_lost(exc)
self._closing = True
if self._request_handler is not None:
self._request_handler.cancel()
def data_received(self, data):
super().data_received(data)
# reading request
if not self._reading_request:
self._reading_request = True
def keep_alive(self, val):
"""Set keep-alive connection mode.
:param bool val: new state.
"""
self._keepalive = val
def log_access(self, message, environ, response, time):
if self.access_logger:
self.access_logger.log(message, environ, response,
self.transport, time)
def log_debug(self, *args, **kw):
if self.debug:
self.logger.debug(*args, **kw)
def log_exception(self, *args, **kw):
self.logger.exception(*args, **kw)
@asyncio.coroutine
def start(self):
"""Start processing of incoming requests.
It reads the request line, request headers and request payload, then
calls the handle_request() method. Subclasses have to override
handle_request(). start() handles various exceptions in request
or response handling. The connection is always closed unless
keep_alive(True) is specified.
"""
reader = self.reader
try:
while not self._closing:
message = None
self._keepalive = False
self._request_count += 1
self._reading_request = False
payload = None
with Timeout(max(self._slow_request_timeout,
self._keepalive_timeout),
loop=self._loop):
# read HTTP request method
prefix = reader.set_parser(self._request_prefix)
yield from prefix.read()
# start reading request
self._reading_request = True
# start slow request timer
# read request headers
httpstream = reader.set_parser(self._request_parser)
message = yield from httpstream.read()
# request may not have payload
try:
content_length = int(
message.headers.get(hdrs.CONTENT_LENGTH, 0))
except ValueError:
raise errors.InvalidHeader(hdrs.CONTENT_LENGTH) from None
if (content_length > 0 or
message.method == 'CONNECT' or
hdrs.SEC_WEBSOCKET_KEY1 in message.headers or
'chunked' in message.headers.get(
hdrs.TRANSFER_ENCODING, '')):
payload = streams.FlowControlStreamReader(
reader, loop=self._loop)
reader.set_parser(
aiohttp.HttpPayloadParser(message), payload)
else:
payload = EMPTY_PAYLOAD
yield from self.handle_request(message, payload)
if payload and not payload.is_eof():
self.log_debug('Uncompleted request.')
self._closing = True
else:
reader.unset_parser()
if not self._keepalive or not self._keepalive_timeout:
self._closing = True
except asyncio.CancelledError:
self.log_debug(
'Request handler cancelled.')
return
except asyncio.TimeoutError:
self.log_debug(
'Request handler timed out.')
return
except errors.ClientDisconnectedError:
self.log_debug(
'Ignored premature client disconnection #1.')
return
except errors.HttpProcessingError as exc:
yield from self.handle_error(exc.code, message,
None, exc, exc.headers,
exc.message)
except Exception as exc:
yield from self.handle_error(500, message, None, exc)
finally:
self._request_handler = None
if self.transport is None:
self.log_debug(
'Ignored premature client disconnection #2.')
else:
self.transport.close()
def handle_error(self, status=500, message=None,
payload=None, exc=None, headers=None, reason=None):
"""Handle errors.
Returns an HTTP response with the specified status code. Logs additional
information. It always closes the current connection."""
now = self._loop.time()
try:
if self.transport is None:
# client has been disconnected during writing.
return ()
if status == 500:
self.log_exception("Error handling request")
try:
if reason is None or reason == '':
reason, msg = RESPONSES[status]
else:
msg = reason
except KeyError:
status = 500
reason, msg = '???', ''
if self.debug and exc is not None:
try:
tb = traceback.format_exc()
tb = html_escape(tb)
msg += '<br><h2>Traceback:</h2>\n<pre>{}</pre>'.format(tb)
except:
pass
html = DEFAULT_ERROR_MESSAGE.format(
status=status, reason=reason, message=msg).encode('utf-8')
response = aiohttp.Response(self.writer, status, close=True)
response.add_header(hdrs.CONTENT_TYPE, 'text/html; charset=utf-8')
response.add_header(hdrs.CONTENT_LENGTH, str(len(html)))
if headers is not None:
for name, value in headers:
response.add_header(name, value)
response.send_headers()
response.write(html)
# disable CORK, enable NODELAY if needed
self.writer.set_tcp_nodelay(True)
drain = response.write_eof()
self.log_access(message, None, response, self._loop.time() - now)
return drain
finally:
self.keep_alive(False)
def handle_request(self, message, payload):
"""Handle a single HTTP request.
Subclasses should override this method. By default it always
returns a 404 response.
:param message: Request headers
:type message: aiohttp.protocol.HttpRequestParser
:param payload: Request payload
:type payload: aiohttp.streams.FlowControlStreamReader
"""
now = self._loop.time()
response = aiohttp.Response(
self.writer, 404, http_version=message.version, close=True)
body = b'Page Not Found!'
response.add_header(hdrs.CONTENT_TYPE, 'text/plain')
response.add_header(hdrs.CONTENT_LENGTH, str(len(body)))
response.send_headers()
response.write(body)
drain = response.write_eof()
self.keep_alive(False)
self.log_access(message, None, response, self._loop.time() - now)
return drain
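# --- Editor's illustrative sketch (not part of the original diff): the usage
# pattern the class docstring describes: subclass ServerHttpProtocol and
# override handle_request(). Mirrors the 404 handler above, answering 200
# instead; 'OKProtocol' is a hypothetical name.
class OKProtocol(ServerHttpProtocol):

    @asyncio.coroutine
    def handle_request(self, message, payload):
        response = aiohttp.Response(
            self.writer, 200, http_version=message.version)
        body = b'OK'
        response.add_header(hdrs.CONTENT_TYPE, 'text/plain')
        response.add_header(hdrs.CONTENT_LENGTH, str(len(body)))
        response.send_headers()
        response.write(body)
        yield from response.write_eof()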
@ -1,34 +1,71 @@
from aiohttp.frozenlist import FrozenList
import asyncio
from itertools import count
__all__ = ('Signal',)
class BaseSignal(list):
class Signal(FrozenList):
@asyncio.coroutine
def _send(self, *args, **kwargs):
for receiver in self:
res = receiver(*args, **kwargs)
if asyncio.iscoroutine(res) or isinstance(res, asyncio.Future):
yield from res
def copy(self):
raise NotImplementedError("copy() is forbidden")
def sort(self):
raise NotImplementedError("sort() is forbidden")
class Signal(BaseSignal):
"""Coroutine-based signal implementation.
To connect a callback to a signal, use any list method.
Signals are fired using the send() coroutine, which takes named
Signals are fired using the :meth:`send` coroutine, which takes named
arguments.
"""
__slots__ = ('_owner',)
def __init__(self, owner):
def __init__(self, app):
super().__init__()
self._owner = owner
def __repr__(self):
return '<Signal owner={}, frozen={}, {!r}>'.format(self._owner,
self.frozen,
list(self))
self._app = app
klass = self.__class__
self._name = klass.__module__ + ':' + klass.__qualname__
self._pre = app.on_pre_signal
self._post = app.on_post_signal
async def send(self, *args, **kwargs):
@asyncio.coroutine
def send(self, *args, **kwargs):
"""
Sends data to all registered receivers.
"""
if not self.frozen:
raise RuntimeError("Cannot send non-frozen signal.")
ordinal = None
debug = self._app._debug
if debug:
ordinal = self._pre.ordinal()
yield from self._pre.send(ordinal, self._name, *args, **kwargs)
yield from self._send(*args, **kwargs)
if debug:
yield from self._post.send(ordinal, self._name, *args, **kwargs)
for receiver in self:
await receiver(*args, **kwargs) # type: ignore
class DebugSignal(BaseSignal):
@asyncio.coroutine
def send(self, ordinal, name, *args, **kwargs):
yield from self._send(ordinal, name, *args, **kwargs)
class PreSignal(DebugSignal):
def __init__(self):
super().__init__()
self._counter = count(1)
def ordinal(self):
return next(self._counter)
class PostSignal(DebugSignal):
pass
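# --- Editor's illustrative sketch (not part of the original diff): driving
# the FrozenList-based Signal variant above by hand. Receivers are connected
# with ordinary list methods, the list is frozen, and send() awaits each
# receiver in registration order. Names below are hypothetical.
import asyncio

async def on_event(value):
    print('received', value)

async def main():
    sig = Signal(owner='example')  # any owner object works; a str stands in
    sig.append(on_event)
    sig.freeze()                   # send() refuses to fire non-frozen signals
    await sig.send(42)

asyncio.get_event_loop().run_until_complete(main())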
@ -1,17 +0,0 @@
from typing import Any, Generic, TypeVar
from aiohttp.frozenlist import FrozenList
__all__ = ('Signal',)
_T = TypeVar('_T')
class Signal(FrozenList[_T], Generic[_T]):
def __init__(self, owner: Any) -> None: ...
def __repr__(self) -> str: ...
async def send(self, *args: Any, **kwargs: Any) -> None: ...
@ -1,92 +1,74 @@
import asyncio
import collections
import warnings
from typing import List # noqa
from typing import Awaitable, Callable, Generic, Optional, Tuple, TypeVar
import functools
import sys
import traceback
from .base_protocol import BaseProtocol
from .helpers import BaseTimerContext, set_exception, set_result
from . import helpers
from .log import internal_logger
try: # pragma: no cover
from typing import Deque # noqa
except ImportError:
from typing_extensions import Deque # noqa
__all__ = (
'EMPTY_PAYLOAD', 'EofStream', 'StreamReader', 'DataQueue',
'FlowControlDataQueue')
'EofStream', 'StreamReader', 'DataQueue', 'ChunksQueue',
'FlowControlStreamReader',
'FlowControlDataQueue', 'FlowControlChunksQueue')
DEFAULT_LIMIT = 2 ** 16
PY_35 = sys.version_info >= (3, 5)
PY_352 = sys.version_info >= (3, 5, 2)
_T = TypeVar('_T')
EOF_MARKER = b''
DEFAULT_LIMIT = 2 ** 16
class EofStream(Exception):
"""eof stream indication."""
class AsyncStreamIterator(Generic[_T]):
def __init__(self, read_func: Callable[[], Awaitable[_T]]) -> None:
self.read_func = read_func
if PY_35:
class AsyncStreamIterator:
def __aiter__(self) -> 'AsyncStreamIterator[_T]':
return self
async def __anext__(self) -> _T:
try:
rv = await self.read_func()
except EofStream:
raise StopAsyncIteration # NOQA
if rv == b'':
raise StopAsyncIteration # NOQA
return rv
def __init__(self, read_func):
self.read_func = read_func
def __aiter__(self):
return self
class ChunkTupleAsyncStreamIterator:
if not PY_352: # pragma: no cover
__aiter__ = asyncio.coroutine(__aiter__)
def __init__(self, stream: 'StreamReader') -> None:
self._stream = stream
def __aiter__(self) -> 'ChunkTupleAsyncStreamIterator':
return self
async def __anext__(self) -> Tuple[bytes, bool]:
rv = await self._stream.readchunk()
if rv == (b'', False):
raise StopAsyncIteration # NOQA
return rv
@asyncio.coroutine
def __anext__(self):
try:
rv = yield from self.read_func()
except EofStream:
raise StopAsyncIteration # NOQA
if rv == EOF_MARKER:
raise StopAsyncIteration # NOQA
return rv
class AsyncStreamReaderMixin:
def __aiter__(self) -> AsyncStreamIterator[bytes]:
return AsyncStreamIterator(self.readline) # type: ignore
def iter_chunked(self, n: int) -> AsyncStreamIterator[bytes]:
"""Returns an asynchronous iterator that yields chunks of size n.
if PY_35:
def __aiter__(self):
return AsyncStreamIterator(self.readline)
Python-3.5 available for Python 3.5+ only
"""
return AsyncStreamIterator(lambda: self.read(n)) # type: ignore
if not PY_352: # pragma: no cover
__aiter__ = asyncio.coroutine(__aiter__)
def iter_any(self) -> AsyncStreamIterator[bytes]:
"""Returns an asynchronous iterator that yields all the available
data as soon as it is received
def iter_chunked(self, n):
"""Returns an asynchronous iterator that yields chunks of size n.
Python-3.5 available for Python 3.5+ only
"""
return AsyncStreamIterator(self.readany) # type: ignore
Python-3.5 available for Python 3.5+ only
"""
return AsyncStreamIterator(lambda: self.read(n))
def iter_chunks(self) -> ChunkTupleAsyncStreamIterator:
"""Returns an asynchronous iterator that yields chunks of data
as they are received by the server. The yielded objects are tuples
of (bytes, bool) as returned by the StreamReader.readchunk method.
def iter_any(self):
"""Returns an asynchronous iterator that yields slices of data
as they come.
Python-3.5 available for Python 3.5+ only
"""
return ChunkTupleAsyncStreamIterator(self) # type: ignore
Python-3.5 available for Python 3.5+ only
"""
return AsyncStreamIterator(self.readany)
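# --- Editor's illustrative sketch (not part of the original diff): what
# iter_chunked()/iter_any() above enable on the client side. A response's
# content is a stream reader, so the body can be consumed incrementally with
# 'async for' (Python 3.5+ syntax); the URL is only an example.
import asyncio
import aiohttp

async def download(url):
    async with aiohttp.ClientSession() as session:
        async with session.get(url) as resp:
            async for chunk in resp.content.iter_chunked(1024):
                print('got', len(chunk), 'bytes')

asyncio.get_event_loop().run_until_complete(download('http://python.org'))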
class StreamReader(AsyncStreamReaderMixin):
@ -105,182 +87,127 @@ class StreamReader(AsyncStreamReaderMixin):
total_bytes = 0
def __init__(self, protocol: BaseProtocol,
*, limit: int=DEFAULT_LIMIT,
timer: Optional[BaseTimerContext]=None,
loop: Optional[asyncio.AbstractEventLoop]=None) -> None:
self._protocol = protocol
self._low_water = limit
self._high_water = limit * 2
def __init__(self, limit=DEFAULT_LIMIT, timeout=None, loop=None):
self._limit = limit
if loop is None:
loop = asyncio.get_event_loop()
self._loop = loop
self._size = 0
self._cursor = 0
self._http_chunk_splits = None # type: Optional[List[int]]
self._buffer = collections.deque() # type: Deque[bytes]
self._buffer = collections.deque()
self._buffer_size = 0
self._buffer_offset = 0
self._eof = False
self._waiter = None # type: Optional[asyncio.Future[None]]
self._eof_waiter = None # type: Optional[asyncio.Future[None]]
self._exception = None # type: Optional[BaseException]
self._timer = timer
self._eof_callbacks = [] # type: List[Callable[[], None]]
def __repr__(self) -> str:
info = [self.__class__.__name__]
if self._size:
info.append('%d bytes' % self._size)
self._waiter = None
self._canceller = None
self._eof_waiter = None
self._exception = None
self._timeout = timeout
def __repr__(self):
info = ['StreamReader']
if self._buffer_size:
info.append('%d bytes' % self._buffer_size)
if self._eof:
info.append('eof')
if self._low_water != DEFAULT_LIMIT:
info.append('low=%d high=%d' % (self._low_water, self._high_water))
if self._limit != DEFAULT_LIMIT:
info.append('l=%d' % self._limit)
if self._waiter:
info.append('w=%r' % self._waiter)
if self._exception:
info.append('e=%r' % self._exception)
return '<%s>' % ' '.join(info)
def exception(self) -> Optional[BaseException]:
def exception(self):
return self._exception
def set_exception(self, exc: BaseException) -> None:
def set_exception(self, exc):
self._exception = exc
self._eof_callbacks.clear()
waiter = self._waiter
if waiter is not None:
self._waiter = None
set_exception(waiter, exc)
if not waiter.cancelled():
waiter.set_exception(exc)
waiter = self._eof_waiter
if waiter is not None:
self._eof_waiter = None
set_exception(waiter, exc)
canceller = self._canceller
if canceller is not None:
self._canceller = None
canceller.cancel()
def on_eof(self, callback: Callable[[], None]) -> None:
if self._eof:
try:
callback()
except Exception:
internal_logger.exception('Exception in eof callback')
else:
self._eof_callbacks.append(callback)
def feed_eof(self) -> None:
def feed_eof(self):
self._eof = True
waiter = self._waiter
if waiter is not None:
self._waiter = None
set_result(waiter, None)
if not waiter.cancelled():
waiter.set_result(True)
canceller = self._canceller
if canceller is not None:
self._canceller = None
canceller.cancel()
waiter = self._eof_waiter
if waiter is not None:
self._eof_waiter = None
set_result(waiter, None)
for cb in self._eof_callbacks:
try:
cb()
except Exception:
internal_logger.exception('Exception in eof callback')
if not waiter.cancelled():
waiter.set_result(True)
self._eof_callbacks.clear()
def is_eof(self) -> bool:
def is_eof(self):
"""Return True if 'feed_eof' was called."""
return self._eof
def at_eof(self) -> bool:
def at_eof(self):
"""Return True if the buffer is empty and 'feed_eof' was called."""
return self._eof and not self._buffer
async def wait_eof(self) -> None:
@asyncio.coroutine
def wait_eof(self):
if self._eof:
return
assert self._eof_waiter is None
self._eof_waiter = self._loop.create_future()
self._eof_waiter = helpers.create_future(self._loop)
try:
await self._eof_waiter
yield from self._eof_waiter
finally:
self._eof_waiter = None
def unread_data(self, data: bytes) -> None:
def unread_data(self, data):
""" rollback reading some data from stream, inserting it to buffer head.
"""
warnings.warn("unread_data() is deprecated "
"and will be removed in future releases (#3260)",
DeprecationWarning,
stacklevel=2)
if not data:
return
if self._buffer_offset:
self._buffer[0] = self._buffer[0][self._buffer_offset:]
self._buffer_offset = 0
self._size += len(data)
self._cursor -= len(data)
self._buffer.appendleft(data)
self._eof_counter = 0
self._buffer_size += len(data)
# TODO: size is ignored, remove the param later
def feed_data(self, data: bytes, size: int=0) -> None:
def feed_data(self, data):
assert not self._eof, 'feed_data after feed_eof'
if not data:
return
self._size += len(data)
self._buffer.append(data)
self._buffer_size += len(data)
self.total_bytes += len(data)
waiter = self._waiter
if waiter is not None:
self._waiter = None
set_result(waiter, None)
if (self._size > self._high_water and
not self._protocol._reading_paused):
self._protocol.pause_reading()
def begin_http_chunk_receiving(self) -> None:
if self._http_chunk_splits is None:
if self.total_bytes:
raise RuntimeError("Called begin_http_chunk_receiving when"
"some data was already fed")
self._http_chunk_splits = []
def end_http_chunk_receiving(self) -> None:
if self._http_chunk_splits is None:
raise RuntimeError("Called end_chunk_receiving without calling "
"begin_chunk_receiving first")
# self._http_chunk_splits contains logical byte offsets from start of
# the body transfer. Each offset is the offset of the end of a chunk.
# "Logical" means bytes, accessible for a user.
# If no chunks containig logical data were received, current position
# is difinitely zero.
pos = self._http_chunk_splits[-1] if self._http_chunk_splits else 0
if self.total_bytes == pos:
# We should not add empty chunks here. So we check for that.
# Note, when chunked + gzip is used, we can receive a chunk
# of compressed data, but that data may not be enough for gzip FSM
# to yield any uncompressed data. That's why current position may
# not change after receiving a chunk.
return
if not waiter.cancelled():
waiter.set_result(False)
self._http_chunk_splits.append(self.total_bytes)
canceller = self._canceller
if canceller is not None:
self._canceller = None
canceller.cancel()
# wake up readchunk when end of http chunk received
waiter = self._waiter
if waiter is not None:
self._waiter = None
set_result(waiter, None)
async def _wait(self, func_name: str) -> None:
@asyncio.coroutine
def _wait(self, func_name):
# StreamReader uses a future to link the protocol feed_data() method
# to a read coroutine. Running two read coroutines at the same time
# would have an unexpected behaviour. It would not be possible to know
@ -288,18 +215,21 @@ class StreamReader(AsyncStreamReaderMixin):
if self._waiter is not None:
raise RuntimeError('%s() called while another coroutine is '
'already waiting for incoming data' % func_name)
waiter = self._waiter = self._loop.create_future()
waiter = self._waiter = helpers.create_future(self._loop)
if self._timeout:
self._canceller = self._loop.call_later(self._timeout,
self.set_exception,
asyncio.TimeoutError())
try:
if self._timer:
with self._timer:
await waiter
else:
await waiter
yield from waiter
finally:
self._waiter = None
if self._canceller is not None:
self._canceller.cancel()
self._canceller = None
async def readline(self) -> bytes:
@asyncio.coroutine
def readline(self):
if self._exception is not None:
raise self._exception
@ -318,18 +248,19 @@ class StreamReader(AsyncStreamReaderMixin):
if ichar:
not_enough = False
if line_size > self._high_water:
if line_size > self._limit:
raise ValueError('Line is too long')
if self._eof:
break
if not_enough:
await self._wait('readline')
yield from self._wait('readline')
return b''.join(line)
async def read(self, n: int=-1) -> bytes:
@asyncio.coroutine
def read(self, n=-1):
if self._exception is not None:
raise self._exception
@ -341,12 +272,13 @@ class StreamReader(AsyncStreamReaderMixin):
if self._eof and not self._buffer:
self._eof_counter = getattr(self, '_eof_counter', 0) + 1
if self._eof_counter > 5:
stack = traceback.format_stack()
internal_logger.warning(
'Multiple access to StreamReader in eof state, '
'might be infinite loop.', stack_info=True)
'might be infinite loop: \n%s', stack)
if not n:
return b''
return EOF_MARKER
if n < 0:
# This used to just loop creating a new waiter hoping to
@ -355,83 +287,50 @@ class StreamReader(AsyncStreamReaderMixin):
# bytes. So just call self.readany() until EOF.
blocks = []
while True:
block = await self.readany()
block = yield from self.readany()
if not block:
break
blocks.append(block)
return b''.join(blocks)
# TODO: should be `if` instead of `while`
# because waiter maybe triggered on chunk end,
# without feeding any data
while not self._buffer and not self._eof:
await self._wait('read')
if not self._buffer and not self._eof:
yield from self._wait('read')
return self._read_nowait(n)
async def readany(self) -> bytes:
@asyncio.coroutine
def readany(self):
if self._exception is not None:
raise self._exception
# TODO: should be `if` instead of `while`
# because waiter maybe triggered on chunk end,
# without feeding any data
while not self._buffer and not self._eof:
await self._wait('readany')
if not self._buffer and not self._eof:
yield from self._wait('readany')
return self._read_nowait(-1)
async def readchunk(self) -> Tuple[bytes, bool]:
"""Returns a tuple of (data, end_of_http_chunk). When chunked transfer
encoding is used, end_of_http_chunk is a boolean indicating if the end
of the data corresponds to the end of an HTTP chunk, otherwise it is
always False.
"""
while True:
if self._exception is not None:
raise self._exception
while self._http_chunk_splits:
pos = self._http_chunk_splits.pop(0)
if pos == self._cursor:
return (b"", True)
if pos > self._cursor:
return (self._read_nowait(pos-self._cursor), True)
internal_logger.warning('Skipping HTTP chunk end due to data '
'consumption beyond chunk boundary')
if self._buffer:
return (self._read_nowait_chunk(-1), False)
# return (self._read_nowait(-1), False)
if self._eof:
# Special case for signifying EOF.
# (b'', True) is not a final return value actually.
return (b'', False)
await self._wait('readchunk')
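# --- Editor's illustrative sketch (not part of the original diff): consuming
# readchunk() above. The boolean in each (data, end_of_http_chunk) tuple lets
# a caller reassemble original HTTP chunk boundaries from arbitrarily split
# reads. 'collect_http_chunks' is a hypothetical helper.
async def collect_http_chunks(stream):
    chunks, current = [], bytearray()
    while True:
        data, end_of_chunk = await stream.readchunk()
        if not data and not end_of_chunk:  # (b'', False) signals EOF
            break
        current.extend(data)
        if end_of_chunk and current:
            chunks.append(bytes(current))
            current.clear()
    if current:  # trailing data without a chunk-end marker
        chunks.append(bytes(current))
    return chunks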
async def readexactly(self, n: int) -> bytes:
@asyncio.coroutine
def readexactly(self, n):
if self._exception is not None:
raise self._exception
blocks = [] # type: List[bytes]
blocks = []
while n > 0:
block = await self.read(n)
block = yield from self.read(n)
if not block:
partial = b''.join(blocks)
raise asyncio.IncompleteReadError(
raise asyncio.streams.IncompleteReadError(
partial, len(partial) + n)
blocks.append(block)
n -= len(block)
return b''.join(blocks)
def read_nowait(self, n: int=-1) -> bytes:
def read_nowait(self, n=-1):
# default was changed to be consistent with .read(-1)
#
# I believe most users don't know about the method and
# they are not affected.
assert n is not None, "n should be -1"
if self._exception is not None:
raise self._exception
@ -441,7 +340,7 @@ class StreamReader(AsyncStreamReaderMixin):
return self._read_nowait(n)
def _read_nowait_chunk(self, n: int) -> bytes:
def _read_nowait_chunk(self, n):
first_buffer = self._buffer[0]
offset = self._buffer_offset
if n != -1 and len(first_buffer) - offset > n:
@ -456,20 +355,10 @@ class StreamReader(AsyncStreamReaderMixin):
else:
data = self._buffer.popleft()
self._size -= len(data)
self._cursor += len(data)
chunk_splits = self._http_chunk_splits
# Prevent memory leak: drop useless chunk splits
while chunk_splits and chunk_splits[0] < self._cursor:
chunk_splits.pop(0)
if self._size < self._low_water and self._protocol._reading_paused:
self._protocol.resume_reading()
self._buffer_size -= len(data)
return data
def _read_nowait(self, n: int) -> bytes:
""" Read not more than n bytes, or whole buffer is n == -1 """
def _read_nowait(self, n):
chunks = []
while self._buffer:
@ -480,115 +369,111 @@ class StreamReader(AsyncStreamReaderMixin):
if n == 0:
break
return b''.join(chunks) if chunks else b''
return b''.join(chunks) if chunks else EOF_MARKER
class EmptyStreamReader(AsyncStreamReaderMixin):
def exception(self) -> Optional[BaseException]:
def exception(self):
return None
def set_exception(self, exc: BaseException) -> None:
def set_exception(self, exc):
pass
def on_eof(self, callback: Callable[[], None]) -> None:
try:
callback()
except Exception:
internal_logger.exception('Exception in eof callback')
def feed_eof(self) -> None:
def feed_eof(self):
pass
def is_eof(self) -> bool:
def is_eof(self):
return True
def at_eof(self) -> bool:
def at_eof(self):
return True
async def wait_eof(self) -> None:
@asyncio.coroutine
def wait_eof(self):
return
def feed_data(self, data: bytes, n: int=0) -> None:
def feed_data(self, data):
pass
async def readline(self) -> bytes:
return b''
async def read(self, n: int=-1) -> bytes:
return b''
async def readany(self) -> bytes:
return b''
@asyncio.coroutine
def readline(self):
return EOF_MARKER
async def readchunk(self) -> Tuple[bytes, bool]:
return (b'', True)
@asyncio.coroutine
def read(self, n=-1):
return EOF_MARKER
async def readexactly(self, n: int) -> bytes:
raise asyncio.IncompleteReadError(b'', n)
@asyncio.coroutine
def readany(self):
return EOF_MARKER
def read_nowait(self) -> bytes:
return b''
@asyncio.coroutine
def readexactly(self, n):
raise asyncio.streams.IncompleteReadError(b'', n)
def read_nowait(self):
return EOF_MARKER
EMPTY_PAYLOAD = EmptyStreamReader()
class DataQueue(Generic[_T]):
class DataQueue:
"""DataQueue is a general-purpose blocking queue with one reader."""
def __init__(self, loop: asyncio.AbstractEventLoop) -> None:
def __init__(self, *, loop=None):
self._loop = loop
self._eof = False
self._waiter = None # type: Optional[asyncio.Future[None]]
self._exception = None # type: Optional[BaseException]
self._waiter = None
self._exception = None
self._size = 0
self._buffer = collections.deque() # type: Deque[Tuple[_T, int]]
def __len__(self) -> int:
return len(self._buffer)
self._buffer = collections.deque()
def is_eof(self) -> bool:
def is_eof(self):
return self._eof
def at_eof(self) -> bool:
def at_eof(self):
return self._eof and not self._buffer
def exception(self) -> Optional[BaseException]:
def exception(self):
return self._exception
def set_exception(self, exc: BaseException) -> None:
self._eof = True
def set_exception(self, exc):
self._exception = exc
waiter = self._waiter
if waiter is not None:
self._waiter = None
set_exception(waiter, exc)
if not waiter.done():
waiter.set_exception(exc)
def feed_data(self, data: _T, size: int=0) -> None:
def feed_data(self, data, size=0):
self._size += size
self._buffer.append((data, size))
waiter = self._waiter
if waiter is not None:
self._waiter = None
set_result(waiter, None)
if not waiter.cancelled():
waiter.set_result(True)
def feed_eof(self) -> None:
def feed_eof(self):
self._eof = True
waiter = self._waiter
if waiter is not None:
self._waiter = None
set_result(waiter, None)
if not waiter.cancelled():
waiter.set_result(False)
async def read(self) -> _T:
@asyncio.coroutine
def read(self):
if not self._buffer and not self._eof:
if self._exception is not None:
raise self._exception
assert not self._waiter
self._waiter = self._loop.create_future()
self._waiter = helpers.create_future(self._loop)
try:
await self._waiter
yield from self._waiter
except (asyncio.CancelledError, asyncio.TimeoutError):
self._waiter = None
raise
@ -603,32 +488,185 @@ class DataQueue(Generic[_T]):
else:
raise EofStream
def __aiter__(self) -> AsyncStreamIterator[_T]:
return AsyncStreamIterator(self.read)
if PY_35:
def __aiter__(self):
return AsyncStreamIterator(self.read)
if not PY_352: # pragma: no cover
__aiter__ = asyncio.coroutine(__aiter__)
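# --- Editor's illustrative sketch (not part of the original diff): DataQueue
# as a one-reader blocking queue, written against the async/await variant
# above. feed_data() wakes a pending read(); after feed_eof() a drained
# queue raises EofStream.
import asyncio

async def main():
    loop = asyncio.get_event_loop()
    queue = DataQueue(loop=loop)
    queue.feed_data('item', size=4)
    queue.feed_eof()
    assert await queue.read() == 'item'
    try:
        await queue.read()
    except EofStream:
        print('queue drained')

asyncio.get_event_loop().run_until_complete(main())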
class ChunksQueue(DataQueue):
"""Like a :class:`DataQueue`, but for binary chunked data transfer."""
@asyncio.coroutine
def read(self):
try:
return (yield from super().read())
except EofStream:
return EOF_MARKER
readany = read
def maybe_resume(func):
if asyncio.iscoroutinefunction(func):
@asyncio.coroutine
@functools.wraps(func)
def wrapper(self, *args, **kw):
result = yield from func(self, *args, **kw)
self._check_buffer_size()
return result
else:
@functools.wraps(func)
def wrapper(self, *args, **kw):
result = func(self, *args, **kw)
self._check_buffer_size()
return result
return wrapper
class FlowControlStreamReader(StreamReader):
def __init__(self, stream, limit=DEFAULT_LIMIT, *args, **kwargs):
super().__init__(*args, **kwargs)
self._stream = stream
self._b_limit = limit * 2
# resume transport reading
if stream.paused:
try:
self._stream.transport.resume_reading()
except (AttributeError, NotImplementedError):
pass
else:
self._stream.paused = False
def _check_buffer_size(self):
if self._stream.paused:
if self._buffer_size < self._b_limit:
try:
self._stream.transport.resume_reading()
except (AttributeError, NotImplementedError):
pass
else:
self._stream.paused = False
else:
if self._buffer_size > self._b_limit:
try:
self._stream.transport.pause_reading()
except (AttributeError, NotImplementedError):
pass
else:
self._stream.paused = True
def feed_data(self, data, size=0):
has_waiter = self._waiter is not None and not self._waiter.cancelled()
super().feed_data(data)
if (not self._stream.paused and
not has_waiter and self._buffer_size > self._b_limit):
try:
self._stream.transport.pause_reading()
except (AttributeError, NotImplementedError):
pass
else:
self._stream.paused = True
@maybe_resume
@asyncio.coroutine
def read(self, n=-1):
return (yield from super().read(n))
@maybe_resume
@asyncio.coroutine
def readline(self):
return (yield from super().readline())
class FlowControlDataQueue(DataQueue[_T]):
@maybe_resume
@asyncio.coroutine
def readany(self):
return (yield from super().readany())
@maybe_resume
@asyncio.coroutine
def readexactly(self, n):
return (yield from super().readexactly(n))
@maybe_resume
def read_nowait(self, n=-1):
return super().read_nowait(n)
class FlowControlDataQueue(DataQueue):
"""FlowControlDataQueue resumes and pauses an underlying stream.
It is a destination for parsed data."""
def __init__(self, protocol: BaseProtocol, *,
limit: int=DEFAULT_LIMIT,
loop: asyncio.AbstractEventLoop) -> None:
def __init__(self, stream, *, limit=DEFAULT_LIMIT, loop=None):
super().__init__(loop=loop)
self._protocol = protocol
self._stream = stream
self._limit = limit * 2
def feed_data(self, data: _T, size: int=0) -> None:
# resume transport reading
if stream.paused:
try:
self._stream.transport.resume_reading()
except (AttributeError, NotImplementedError):
pass
else:
self._stream.paused = False
def feed_data(self, data, size):
has_waiter = self._waiter is not None and not self._waiter.cancelled()
super().feed_data(data, size)
if self._size > self._limit and not self._protocol._reading_paused:
self._protocol.pause_reading()
if (not self._stream.paused and
not has_waiter and self._size > self._limit):
try:
self._stream.transport.pause_reading()
except (AttributeError, NotImplementedError):
pass
else:
self._stream.paused = True
@asyncio.coroutine
def read(self):
result = yield from super().read()
if self._stream.paused:
if self._size < self._limit:
try:
self._stream.transport.resume_reading()
except (AttributeError, NotImplementedError):
pass
else:
self._stream.paused = False
else:
if self._size > self._limit:
try:
self._stream.transport.pause_reading()
except (AttributeError, NotImplementedError):
pass
else:
self._stream.paused = True
return result
class FlowControlChunksQueue(FlowControlDataQueue):
async def read(self) -> _T:
@asyncio.coroutine
def read(self):
try:
return await super().read()
finally:
if self._size < self._limit and self._protocol._reading_paused:
self._protocol.resume_reading()
return (yield from super().read())
except EofStream:
return EOF_MARKER
readany = read
@ -1,63 +0,0 @@
"""Helper methods to tune a TCP connection"""
import asyncio
import socket
from contextlib import suppress
from typing import Optional # noqa
__all__ = ('tcp_keepalive', 'tcp_nodelay', 'tcp_cork')
if hasattr(socket, 'TCP_CORK'): # pragma: no cover
CORK = socket.TCP_CORK # type: Optional[int]
elif hasattr(socket, 'TCP_NOPUSH'): # pragma: no cover
CORK = socket.TCP_NOPUSH # type: ignore
else: # pragma: no cover
CORK = None
if hasattr(socket, 'SO_KEEPALIVE'):
def tcp_keepalive(transport: asyncio.Transport) -> None:
sock = transport.get_extra_info('socket')
if sock is not None:
sock.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1)
else:
def tcp_keepalive(
transport: asyncio.Transport) -> None: # pragma: no cover
pass
def tcp_nodelay(transport: asyncio.Transport, value: bool) -> None:
sock = transport.get_extra_info('socket')
if sock is None:
return
if sock.family not in (socket.AF_INET, socket.AF_INET6):
return
value = bool(value)
# socket may be closed already, on windows OSError get raised
with suppress(OSError):
sock.setsockopt(
socket.IPPROTO_TCP, socket.TCP_NODELAY, value)
def tcp_cork(transport: asyncio.Transport, value: bool) -> None:
sock = transport.get_extra_info('socket')
if CORK is None:
return
if sock is None:
return
if sock.family not in (socket.AF_INET, socket.AF_INET6):
return
value = bool(value)
with suppress(OSError):
sock.setsockopt(
socket.IPPROTO_TCP, CORK, value)
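# --- Editor's illustrative sketch (not part of the original diff): applying
# the helpers above to a connected transport. All three degrade gracefully
# (no-op or suppressed OSError) when the platform or socket does not support
# the option. 'tune_transport' is a hypothetical name.
def tune_transport(transport: asyncio.Transport) -> None:
    tcp_keepalive(transport)       # enable SO_KEEPALIVE where available
    tcp_nodelay(transport, True)   # disable Nagle for latency-sensitive writes
    tcp_cork(transport, False)     # ensure CORK/NOPUSH is off before writes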
@ -4,157 +4,67 @@ import asyncio
import contextlib
import functools
import gc
import inspect
import socket
import sys
import unittest
from abc import ABC, abstractmethod
from types import TracebackType
from typing import ( # noqa
TYPE_CHECKING,
Any,
Callable,
Iterator,
List,
Optional,
Type,
Union,
)
from unittest import mock
from multidict import CIMultiDict, CIMultiDictProxy
from yarl import URL
from multidict import CIMultiDict
import aiohttp
from aiohttp.client import (
ClientResponse,
_RequestContextManager,
_WSRequestContextManager,
)
from . import ClientSession, hdrs
from .abc import AbstractCookieJar
from .client_reqrep import ClientResponse # noqa
from .client_ws import ClientWebSocketResponse # noqa
from .helpers import sentinel
from .http import HttpVersion, RawRequestMessage
from .protocol import HttpVersion, RawRequestMessage
from .signals import Signal
from .web import (
Application,
AppRunner,
BaseRunner,
Request,
Server,
ServerRunner,
SockSite,
UrlMappingMatchInfo,
)
from .web_protocol import _RequestHandler
from .web import Application, Request
if TYPE_CHECKING: # pragma: no cover
from ssl import SSLContext
else:
SSLContext = None
PY_35 = sys.version_info >= (3, 5)
def get_unused_port_socket(host: str) -> socket.socket:
return get_port_socket(host, 0)
def get_port_socket(host: str, port: int) -> socket.socket:
s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
s.bind((host, port))
return s
def run_briefly(loop):
@asyncio.coroutine
def once():
pass
t = asyncio.Task(once(), loop=loop)
loop.run_until_complete(t)
def unused_port() -> int:
def unused_port():
"""Return a port that is unused on the current host."""
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
s.bind(('127.0.0.1', 0))
return s.getsockname()[1]
class BaseTestServer(ABC):
__test__ = False
def __init__(self,
*,
scheme: Union[str, object]=sentinel,
loop: Optional[asyncio.AbstractEventLoop]=None,
host: str='127.0.0.1',
port: Optional[int]=None,
skip_url_asserts: bool=False,
**kwargs: Any) -> None:
self._loop = loop
self.runner = None # type: Optional[BaseRunner]
self._root = None # type: Optional[URL]
class TestServer:
def __init__(self, app, *, scheme="http", host='127.0.0.1'):
self.app = app
self._loop = app.loop
self.port = None
self.server = None
self.handler = None
self._root = None
self.host = host
self.port = port
self._closed = False
self.scheme = scheme
self.skip_url_asserts = skip_url_asserts
self._closed = False
async def start_server(self,
loop: Optional[asyncio.AbstractEventLoop]=None,
**kwargs: Any) -> None:
if self.runner:
@asyncio.coroutine
def start_server(self, **kwargs):
if self.server:
return
self._loop = loop
self._ssl = kwargs.pop('ssl', None)
self.runner = await self._make_runner(**kwargs)
await self.runner.setup()
if not self.port:
self.port = 0
_sock = get_port_socket(self.host, self.port)
self.host, self.port = _sock.getsockname()[:2]
site = SockSite(self.runner, sock=_sock, ssl_context=self._ssl)
await site.start()
server = site._server
assert server is not None
sockets = server.sockets
assert sockets is not None
self.port = sockets[0].getsockname()[1]
if self.scheme is sentinel:
if self._ssl:
scheme = 'https'
else:
scheme = 'http'
self.scheme = scheme
self._root = URL('{}://{}:{}'.format(self.scheme,
self.host,
self.port))
@abstractmethod # pragma: no cover
async def _make_runner(self, **kwargs: Any) -> BaseRunner:
pass
def make_url(self, path: str) -> URL:
assert self._root is not None
url = URL(path)
if not self.skip_url_asserts:
assert not url.is_absolute()
return self._root.join(url)
else:
return URL(str(self._root) + path)
@property
def started(self) -> bool:
return self.runner is not None
@property
def closed(self) -> bool:
return self._closed
@property
def handler(self) -> Server:
# for backward compatibility
# web.Server instance
runner = self.runner
assert runner is not None
assert runner.server is not None
return runner.server
    async def close(self) -> None:
        """Close all fixtures created by the test server.

        After that point, the TestServer is no longer usable.

@@ -166,115 +76,99 @@ class BaseTestServer(ABC):
        exit when used as a context manager.
        """
        if self.started and not self.closed:
            assert self.runner is not None
            await self.runner.cleanup()
            self._root = None
            self.port = None
            self._closed = True
def __enter__(self) -> None:
raise TypeError("Use async with instead")
def __exit__(self,
exc_type: Optional[Type[BaseException]],
exc_value: Optional[BaseException],
traceback: Optional[TracebackType]) -> None:
# __exit__ should exist in pair with __enter__ but never executed
pass # pragma: no cover
    async def __aenter__(self) -> 'BaseTestServer':
        await self.start_server(loop=self._loop)
        return self

    async def __aexit__(self,
                        exc_type: Optional[Type[BaseException]],
                        exc_value: Optional[BaseException],
                        traceback: Optional[TracebackType]) -> None:
        await self.close()


class TestServer(BaseTestServer):
def __init__(self, app: Application, *,
scheme: Union[str, object]=sentinel,
host: str='127.0.0.1',
port: Optional[int]=None,
**kwargs: Any):
self.app = app
super().__init__(scheme=scheme, host=host, port=port, **kwargs)
async def _make_runner(self, **kwargs: Any) -> BaseRunner:
return AppRunner(self.app, **kwargs)
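
# Editor's example (assumption: standard TestServer usage, not from this
# file): the server is normally driven as an async context manager, which
# starts it on a free port and cleans it up on exit. Uncalled sketch only.
async def _example_test_server() -> None:  # pragma: no cover
    app = Application()
    server = TestServer(app)
    async with server:
        url = server.make_url('/')  # e.g. http://127.0.0.1:<port>/
        print(url)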
class RawTestServer(BaseTestServer):

    def __init__(self, handler: _RequestHandler, *,
                 scheme: Union[str, object]=sentinel,
                 host: str='127.0.0.1',
                 port: Optional[int]=None,
                 **kwargs: Any) -> None:
        self._handler = handler
        super().__init__(scheme=scheme, host=host, port=port, **kwargs)

    async def _make_runner(self,
                           debug: bool=True,
                           **kwargs: Any) -> ServerRunner:
        srv = Server(
            self._handler, loop=self._loop, debug=debug, **kwargs)
        return ServerRunner(srv, debug=debug, **kwargs)


class TestClient:
    """
    A test client implementation.

    To write functional tests for aiohttp based servers.

    """
__test__ = False
def __init__(self, server: BaseTestServer, *,
cookie_jar: Optional[AbstractCookieJar]=None,
loop: Optional[asyncio.AbstractEventLoop]=None,
**kwargs: Any) -> None:
if not isinstance(server, BaseTestServer):
raise TypeError("server must be TestServer "
"instance, found type: %r" % type(server))
self._server = server
self._loop = loop
if cookie_jar is None:
cookie_jar = aiohttp.CookieJar(unsafe=True, loop=loop)
self._session = ClientSession(loop=loop,
cookie_jar=cookie_jar,
**kwargs)
        self._closed = False
        self._responses = []  # type: List[ClientResponse]
        self._websockets = []  # type: List[ClientWebSocketResponse]
async def start_server(self) -> None:
await self._server.start_server(loop=self._loop)
    @property
    def host(self) -> str:
        return self._server.host

    @property
    def port(self) -> Optional[int]:
        return self._server.port

    @property
    def server(self) -> BaseTestServer:
        return self._server

    @property
    def app(self) -> Application:
        return getattr(self._server, "app", None)

    @property
    def session(self) -> ClientSession:
        """An internal aiohttp.ClientSession.

        Unlike the methods on the TestClient, client session requests
        do not automatically include the host in the url queried, and
@@ -283,91 +177,65 @@ class TestClient:
        """
        return self._session
def make_url(self, path: str) -> URL:
return self._server.make_url(path)
async def _request(self, method: str, path: str,
**kwargs: Any) -> ClientResponse:
resp = await self._session.request(
method, self.make_url(path), **kwargs
)
# save it to close later
self._responses.append(resp)
return resp
    def request(self, method: str, path: str,
                **kwargs: Any) -> _RequestContextManager:
        """Routes a request to the tested http server.

        The interface is identical to aiohttp.ClientSession.request,
        except the loop kwarg is overridden by the instance used by the
        test server.

        """
        return _RequestContextManager(
            self._request(method, path, **kwargs)
        )
    def get(self, path: str, **kwargs: Any) -> _RequestContextManager:
        """Perform an HTTP GET request."""
        return _RequestContextManager(
            self._request(hdrs.METH_GET, path, **kwargs)
        )

    def post(self, path: str, **kwargs: Any) -> _RequestContextManager:
        """Perform an HTTP POST request."""
        return _RequestContextManager(
            self._request(hdrs.METH_POST, path, **kwargs)
        )

    def options(self, path: str, **kwargs: Any) -> _RequestContextManager:
        """Perform an HTTP OPTIONS request."""
        return _RequestContextManager(
            self._request(hdrs.METH_OPTIONS, path, **kwargs)
        )

    def head(self, path: str, **kwargs: Any) -> _RequestContextManager:
        """Perform an HTTP HEAD request."""
        return _RequestContextManager(
            self._request(hdrs.METH_HEAD, path, **kwargs)
        )

    def put(self, path: str, **kwargs: Any) -> _RequestContextManager:
        """Perform an HTTP PUT request."""
        return _RequestContextManager(
            self._request(hdrs.METH_PUT, path, **kwargs)
        )

    def patch(self, path: str, **kwargs: Any) -> _RequestContextManager:
        """Perform an HTTP PATCH request."""
        return _RequestContextManager(
            self._request(hdrs.METH_PATCH, path, **kwargs)
        )

    def delete(self, path: str, **kwargs: Any) -> _RequestContextManager:
        """Perform an HTTP DELETE request."""
        return _RequestContextManager(
            self._request(hdrs.METH_DELETE, path, **kwargs)
        )

    def ws_connect(self, path: str, **kwargs: Any) -> _WSRequestContextManager:
        """Initiate websocket connection.

        The api corresponds to aiohttp.ClientSession.ws_connect.

        """
        return _WSRequestContextManager(
            self._ws_connect(path, **kwargs)
        )
async def _ws_connect(self, path: str,
**kwargs: Any) -> ClientWebSocketResponse:
ws = await self._session.ws_connect(
self.make_url(path), **kwargs)
self._websockets.append(ws)
return ws
    async def close(self) -> None:
        """Close all fixtures created by the test client.

        After that point, the TestClient is no longer usable.

@@ -382,31 +250,26 @@ class TestClient:
        if not self._closed:
            for resp in self._responses:
                resp.close()
            for ws in self._websockets:
                await ws.close()
            await self._session.close()
            await self._server.close()
            self._closed = True
    def __enter__(self) -> None:
        raise TypeError("Use async with instead")

    def __exit__(self,
                 exc_type: Optional[Type[BaseException]],
                 exc: Optional[BaseException],
                 tb: Optional[TracebackType]) -> None:
        # __exit__ should exist in pair with __enter__ but never executed
        pass  # pragma: no cover

    async def __aenter__(self) -> 'TestClient':
        await self.start_server()
        return self

    async def __aexit__(self,
                        exc_type: Optional[Type[BaseException]],
                        exc: Optional[BaseException],
                        tb: Optional[TracebackType]) -> None:
        await self.close()
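
# Editor's example (illustrative, not from the original source): a typical
# functional test wires a handler into an Application and drives it through
# TestClient; the 'hello' handler below is a made-up assumption. Uncalled
# sketch only.
async def _example_test_client() -> None:  # pragma: no cover
    from aiohttp import web  # local import keeps the sketch self-contained

    async def hello(request: Request) -> 'web.Response':
        return web.Response(text='Hello, world')

    app = Application()
    app.router.add_get('/', hello)
    async with TestClient(TestServer(app)) as client:
        resp = await client.get('/')
        assert resp.status == 200
        assert 'Hello' in await resp.text()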
class AioHTTPTestCase(unittest.TestCase):
@@ -419,62 +282,35 @@ class AioHTTPTestCase(unittest.TestCase):
* self.loop (asyncio.BaseEventLoop): the event loop in which the
application and server are running.
* self.app (aiohttp.web.Application): the application returned by
self.get_application()
    Note that the TestClient's methods are asynchronous: you have to
    execute functions on the test client using asynchronous methods.
"""
    async def get_application(self) -> Application:
        """
        This method should be overridden
        to return the aiohttp.web.Application
        object to test.

        """
        return self.get_app()

    def get_app(self) -> Application:
        """Obsolete method used to construct a web application.

        Use the .get_application() coroutine instead.

        """
        raise RuntimeError("Did you forget to define get_application()?")

    def setUp(self) -> None:
        self.loop = setup_test_loop()

        self.app = self.loop.run_until_complete(self.get_application())
        self.server = self.loop.run_until_complete(self.get_server(self.app))
        self.client = self.loop.run_until_complete(
            self.get_client(self.server))

        self.loop.run_until_complete(self.client.start_server())
self.loop.run_until_complete(self.setUpAsync())
async def setUpAsync(self) -> None:
pass
def tearDown(self) -> None:
self.loop.run_until_complete(self.tearDownAsync())
self.loop.run_until_complete(self.client.close())
teardown_test_loop(self.loop)
async def tearDownAsync(self) -> None:
pass
async def get_server(self, app: Application) -> TestServer:
"""Return a TestServer instance."""
return TestServer(app, loop=self.loop)
async def get_client(self, server: TestServer) -> TestClient:
"""Return a TestClient instance."""
return TestClient(server, loop=self.loop)
def unittest_run_loop(func: Any, *args: Any, **kwargs: Any) -> Any:
    """A decorator dedicated to use with asynchronous methods of an
    AioHTTPTestCase.

@@ -482,32 +318,25 @@ def unittest_run_loop(func: Any, *args: Any, **kwargs: Any) -> Any:
    the self.loop of the AioHTTPTestCase.
    """
    @functools.wraps(func, *args, **kwargs)
    def new_func(self: Any, *inner_args: Any, **inner_kwargs: Any) -> Any:
        return self.loop.run_until_complete(
            func(self, *inner_args, **inner_kwargs))

    return new_func
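
# Editor's example (illustrative): unittest_run_loop lets plain unittest run
# an async test method on the self.loop of an AioHTTPTestCase. The class is
# a sketch, built inside a factory so test collectors do not pick it up.
def _example_testcase() -> type:  # pragma: no cover
    class ExampleCase(AioHTTPTestCase):
        async def get_application(self) -> Application:
            return Application()

        @unittest_run_loop
        async def test_root(self) -> None:
            resp = await self.client.get('/')
            assert resp.status in (200, 404)

    return ExampleCase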
_LOOP_FACTORY = Callable[[], asyncio.AbstractEventLoop]
@contextlib.contextmanager
def loop_context(loop_factory: _LOOP_FACTORY=asyncio.new_event_loop,
fast: bool=False) -> Iterator[asyncio.AbstractEventLoop]:
"""A contextmanager that creates an event_loop, for test purposes.
Handles the creation and cleanup of a test loop.
"""
loop = setup_test_loop(loop_factory)
yield loop
teardown_test_loop(loop, fast=fast)
def setup_test_loop(
loop_factory: _LOOP_FACTORY=asyncio.new_event_loop
) -> asyncio.AbstractEventLoop:
"""Create and return an asyncio.BaseEventLoop
instance.
@@ -515,62 +344,37 @@ def setup_test_loop(
once they are done with the loop.
"""
loop = loop_factory()
try:
module = loop.__class__.__module__
skip_watcher = 'uvloop' in module
except AttributeError: # pragma: no cover
# Just in case
skip_watcher = True
asyncio.set_event_loop(loop)
if sys.platform != "win32" and not skip_watcher:
policy = asyncio.get_event_loop_policy()
watcher = asyncio.SafeChildWatcher() # type: ignore
watcher.attach_loop(loop)
with contextlib.suppress(NotImplementedError):
policy.set_child_watcher(watcher)
asyncio.set_event_loop(None)
return loop
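
# Editor's example (not part of the original module): loop_context pairs
# setup_test_loop/teardown_test_loop so each test owns a fresh event loop.
def _example_loop_context() -> None:  # pragma: no cover
    with loop_context() as loop:
        loop.run_until_complete(asyncio.sleep(0))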
def teardown_test_loop(loop: asyncio.AbstractEventLoop,
fast: bool=False) -> None:
"""Teardown and cleanup an event_loop created
by setup_test_loop.
:param loop: the loop to teardown
:type loop: asyncio.BaseEventLoop
"""
closed = loop.is_closed()
if not closed:
loop.call_soon(loop.stop)
loop.run_forever()
loop.close()
if not fast:
gc.collect()
asyncio.set_event_loop(None)
def _create_app_mock() -> mock.MagicMock:
def get_dict(app: Any, key: str) -> Any:
return app.__app_dict[key]
def set_dict(app: Any, key: str, value: Any) -> None:
app.__app_dict[key] = value
app = mock.MagicMock()
app.__app_dict = {}
app.__getitem__ = get_dict
app.__setitem__ = set_dict
app._debug = False
app.on_response_prepare = Signal(app)
app.on_response_prepare.freeze()
return app
def _create_transport(sslcontext: Optional[SSLContext]=None) -> mock.Mock:
transport = mock.Mock()
def get_extra_info(key: str) -> Optional[SSLContext]:
if key == 'sslcontext':
return sslcontext
        else:
            return None

@@ -580,91 +384,102 @@ def _create_transport(sslcontext: Optional[SSLContext]=None) -> mock.Mock:
    transport.get_extra_info.side_effect = get_extra_info
    return transport
def make_mocked_request(method: str, path: str,
                        headers: Any=None, *,
                        match_info: Any=sentinel,
                        version: HttpVersion=HttpVersion(1, 1),
                        closing: bool=False,
                        app: Any=None,
                        writer: Any=sentinel,
                        protocol: Any=sentinel,
                        transport: Any=sentinel,
                        payload: Any=sentinel,
                        sslcontext: Optional[SSLContext]=None,
                        client_max_size: int=1024**2,
                        loop: Any=...) -> Any:
    """Creates a mocked web.Request for testing purposes.

    Useful in unit tests, when spinning up a full web server is overkill or
    specific conditions and errors are hard to trigger.

    :param method: str, that represents HTTP method, like GET, POST.
    :type method: str

    :param path: str, The URL including *PATH INFO* without the host or scheme
    :type path: str

    :param headers: mapping containing the headers. Can be anything accepted
        by the multidict.CIMultiDict constructor.
    :type headers: dict, multidict.CIMultiDict, list of pairs

    :param version: namedtuple with encoded HTTP version
    :type version: aiohttp.protocol.HttpVersion

    :param closing: flag indicates that connection should be closed after
        response.
    :type closing: bool

    :param app: the aiohttp.web application attached for fake request
    :type app: aiohttp.web.Application

    :param writer: object for managing outcoming data
    :type writer: aiohttp.parsers.StreamWriter

    :param transport: asyncio transport instance
    :type transport: asyncio.transports.Transport

    :param payload: raw payload reader object
    :type payload: aiohttp.streams.FlowControlStreamReader

    :param sslcontext: ssl.SSLContext object, for HTTPS connection
    :type sslcontext: ssl.SSLContext
    """
    task = mock.Mock()
    if loop is ...:
        loop = mock.Mock()
        loop.create_future.return_value = ()

    if version < HttpVersion(1, 1):
        closing = True

    if headers:
        headers = CIMultiDictProxy(CIMultiDict(headers))
        raw_hdrs = tuple(
            (k.encode('utf-8'), v.encode('utf-8')) for k, v in headers.items())
    else:
        headers = CIMultiDictProxy(CIMultiDict())
        raw_hdrs = ()

    chunked = 'chunked' in headers.get(hdrs.TRANSFER_ENCODING, '').lower()

    message = RawRequestMessage(
        method, path, version, headers,
        raw_hdrs, closing, False, False, chunked, URL(path))
    if app is None:
        app = _create_app_mock()

    if transport is sentinel:
        transport = _create_transport(sslcontext)

    if protocol is sentinel:
        protocol = mock.Mock()
        protocol.transport = transport

    if writer is sentinel:
        writer = mock.Mock()
        writer.write_headers = make_mocked_coro(None)
        writer.write = make_mocked_coro(None)
        writer.write_eof = make_mocked_coro(None)
        writer.drain = make_mocked_coro(None)
        writer.transport = transport

    protocol.transport = transport
    protocol.writer = writer

    if payload is sentinel:
        payload = mock.Mock()

    req = Request(message, payload,
                  protocol, writer, task, loop,
                  client_max_size=client_max_size)

    match_info = UrlMappingMatchInfo(
        {} if match_info is sentinel else match_info, mock.Mock())
    match_info.add_app(app)
    req._match_info = match_info

return req
def make_mocked_coro(return_value: Any=sentinel,
                     raise_exception: Any=sentinel) -> Any:
    """Creates a coroutine mock."""
    async def mock_coro(*args: Any, **kwargs: Any) -> Any:
        if raise_exception is not sentinel:
            raise raise_exception
        if not inspect.isawaitable(return_value):
            return return_value
        await return_value

    return mock.Mock(wraps=mock_coro)
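
# Editor's example (illustrative): make_mocked_request builds a web.Request
# without running a server, so a handler can be unit-tested directly; the
# handler below is a made-up assumption. Uncalled sketch only.
def _example_mocked_request() -> None:  # pragma: no cover
    from aiohttp import web  # local import keeps the sketch self-contained

    async def handler(request: Request) -> 'web.Response':
        return web.Response(text=request.headers.get('token', ''))

    req = make_mocked_request('GET', '/', headers={'token': 'x'})
    loop = asyncio.new_event_loop()
    try:
        resp = loop.run_until_complete(handler(req))
        assert resp.text == 'x'
    finally:
        loop.close()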

@@ -1,387 +0,0 @@
from types import SimpleNamespace
from typing import TYPE_CHECKING, Awaitable, Callable, Type, Union
import attr
from multidict import CIMultiDict # noqa
from yarl import URL
from .client_reqrep import ClientResponse
from .signals import Signal
if TYPE_CHECKING: # pragma: no cover
from .client import ClientSession # noqa
_SignalArgs = Union[
'TraceRequestStartParams',
'TraceRequestEndParams',
'TraceRequestExceptionParams',
'TraceConnectionQueuedStartParams',
'TraceConnectionQueuedEndParams',
'TraceConnectionCreateStartParams',
'TraceConnectionCreateEndParams',
'TraceConnectionReuseconnParams',
'TraceDnsResolveHostStartParams',
'TraceDnsResolveHostEndParams',
'TraceDnsCacheHitParams',
'TraceDnsCacheMissParams',
'TraceRequestRedirectParams',
'TraceRequestChunkSentParams',
'TraceResponseChunkReceivedParams',
]
_Signal = Signal[Callable[[ClientSession, SimpleNamespace, _SignalArgs],
Awaitable[None]]]
else:
_Signal = Signal
__all__ = (
'TraceConfig', 'TraceRequestStartParams', 'TraceRequestEndParams',
'TraceRequestExceptionParams', 'TraceConnectionQueuedStartParams',
'TraceConnectionQueuedEndParams', 'TraceConnectionCreateStartParams',
'TraceConnectionCreateEndParams', 'TraceConnectionReuseconnParams',
'TraceDnsResolveHostStartParams', 'TraceDnsResolveHostEndParams',
'TraceDnsCacheHitParams', 'TraceDnsCacheMissParams',
'TraceRequestRedirectParams',
'TraceRequestChunkSentParams', 'TraceResponseChunkReceivedParams',
)
class TraceConfig:
"""First-class used to trace requests launched via ClientSession
objects."""
def __init__(
self,
trace_config_ctx_factory: Type[SimpleNamespace]=SimpleNamespace
) -> None:
self._on_request_start = Signal(self) # type: _Signal
self._on_request_chunk_sent = Signal(self) # type: _Signal
self._on_response_chunk_received = Signal(self) # type: _Signal
self._on_request_end = Signal(self) # type: _Signal
self._on_request_exception = Signal(self) # type: _Signal
self._on_request_redirect = Signal(self) # type: _Signal
self._on_connection_queued_start = Signal(self) # type: _Signal
self._on_connection_queued_end = Signal(self) # type: _Signal
self._on_connection_create_start = Signal(self) # type: _Signal
self._on_connection_create_end = Signal(self) # type: _Signal
self._on_connection_reuseconn = Signal(self) # type: _Signal
self._on_dns_resolvehost_start = Signal(self) # type: _Signal
self._on_dns_resolvehost_end = Signal(self) # type: _Signal
self._on_dns_cache_hit = Signal(self) # type: _Signal
self._on_dns_cache_miss = Signal(self) # type: _Signal
self._trace_config_ctx_factory = trace_config_ctx_factory # type: Type[SimpleNamespace] # noqa
def trace_config_ctx(
self,
trace_request_ctx: SimpleNamespace=None
) -> SimpleNamespace: # noqa
""" Return a new trace_config_ctx instance """
return self._trace_config_ctx_factory(
trace_request_ctx=trace_request_ctx)
def freeze(self) -> None:
self._on_request_start.freeze()
self._on_request_chunk_sent.freeze()
self._on_response_chunk_received.freeze()
self._on_request_end.freeze()
self._on_request_exception.freeze()
self._on_request_redirect.freeze()
self._on_connection_queued_start.freeze()
self._on_connection_queued_end.freeze()
self._on_connection_create_start.freeze()
self._on_connection_create_end.freeze()
self._on_connection_reuseconn.freeze()
self._on_dns_resolvehost_start.freeze()
self._on_dns_resolvehost_end.freeze()
self._on_dns_cache_hit.freeze()
self._on_dns_cache_miss.freeze()
@property
def on_request_start(self) -> _Signal:
return self._on_request_start
@property
def on_request_chunk_sent(self) -> _Signal:
return self._on_request_chunk_sent
@property
def on_response_chunk_received(self) -> _Signal:
return self._on_response_chunk_received
@property
def on_request_end(self) -> _Signal:
return self._on_request_end
@property
def on_request_exception(self) -> _Signal:
return self._on_request_exception
@property
def on_request_redirect(self) -> _Signal:
return self._on_request_redirect
@property
def on_connection_queued_start(self) -> _Signal:
return self._on_connection_queued_start
@property
def on_connection_queued_end(self) -> _Signal:
return self._on_connection_queued_end
@property
def on_connection_create_start(self) -> _Signal:
return self._on_connection_create_start
@property
def on_connection_create_end(self) -> _Signal:
return self._on_connection_create_end
@property
def on_connection_reuseconn(self) -> _Signal:
return self._on_connection_reuseconn
@property
def on_dns_resolvehost_start(self) -> _Signal:
return self._on_dns_resolvehost_start
@property
def on_dns_resolvehost_end(self) -> _Signal:
return self._on_dns_resolvehost_end
@property
def on_dns_cache_hit(self) -> _Signal:
return self._on_dns_cache_hit
@property
def on_dns_cache_miss(self) -> _Signal:
return self._on_dns_cache_miss
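
# Editor's example (assumption: the standard ClientSession wiring): trace
# handlers receive (session, trace_config_ctx, params) and are appended to
# the signals exposed above. Uncalled sketch only.
def _example_trace_config() -> 'TraceConfig':  # pragma: no cover
    import time

    async def on_request_start(session, ctx, params) -> None:
        ctx.start = time.monotonic()  # stash state on the per-request ctx

    async def on_request_end(session, ctx, params) -> None:
        elapsed = time.monotonic() - ctx.start
        print('{} {} took {:.3f}s'.format(params.method, params.url, elapsed))

    trace_config = TraceConfig()
    trace_config.on_request_start.append(on_request_start)
    trace_config.on_request_end.append(on_request_end)
    # Typically passed as: ClientSession(trace_configs=[trace_config])
    return trace_config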
@attr.s(frozen=True, slots=True)
class TraceRequestStartParams:
""" Parameters sent by the `on_request_start` signal"""
method = attr.ib(type=str)
url = attr.ib(type=URL)
headers = attr.ib(type='CIMultiDict[str]')
@attr.s(frozen=True, slots=True)
class TraceRequestChunkSentParams:
""" Parameters sent by the `on_request_chunk_sent` signal"""
chunk = attr.ib(type=bytes)
@attr.s(frozen=True, slots=True)
class TraceResponseChunkReceivedParams:
""" Parameters sent by the `on_response_chunk_received` signal"""
chunk = attr.ib(type=bytes)
@attr.s(frozen=True, slots=True)
class TraceRequestEndParams:
""" Parameters sent by the `on_request_end` signal"""
method = attr.ib(type=str)
url = attr.ib(type=URL)
headers = attr.ib(type='CIMultiDict[str]')
response = attr.ib(type=ClientResponse)
@attr.s(frozen=True, slots=True)
class TraceRequestExceptionParams:
""" Parameters sent by the `on_request_exception` signal"""
method = attr.ib(type=str)
url = attr.ib(type=URL)
headers = attr.ib(type='CIMultiDict[str]')
exception = attr.ib(type=BaseException)
@attr.s(frozen=True, slots=True)
class TraceRequestRedirectParams:
""" Parameters sent by the `on_request_redirect` signal"""
method = attr.ib(type=str)
url = attr.ib(type=URL)
headers = attr.ib(type='CIMultiDict[str]')
response = attr.ib(type=ClientResponse)
@attr.s(frozen=True, slots=True)
class TraceConnectionQueuedStartParams:
""" Parameters sent by the `on_connection_queued_start` signal"""
@attr.s(frozen=True, slots=True)
class TraceConnectionQueuedEndParams:
""" Parameters sent by the `on_connection_queued_end` signal"""
@attr.s(frozen=True, slots=True)
class TraceConnectionCreateStartParams:
""" Parameters sent by the `on_connection_create_start` signal"""
@attr.s(frozen=True, slots=True)
class TraceConnectionCreateEndParams:
""" Parameters sent by the `on_connection_create_end` signal"""
@attr.s(frozen=True, slots=True)
class TraceConnectionReuseconnParams:
""" Parameters sent by the `on_connection_reuseconn` signal"""
@attr.s(frozen=True, slots=True)
class TraceDnsResolveHostStartParams:
""" Parameters sent by the `on_dns_resolvehost_start` signal"""
host = attr.ib(type=str)
@attr.s(frozen=True, slots=True)
class TraceDnsResolveHostEndParams:
""" Parameters sent by the `on_dns_resolvehost_end` signal"""
host = attr.ib(type=str)
@attr.s(frozen=True, slots=True)
class TraceDnsCacheHitParams:
""" Parameters sent by the `on_dns_cache_hit` signal"""
host = attr.ib(type=str)
@attr.s(frozen=True, slots=True)
class TraceDnsCacheMissParams:
""" Parameters sent by the `on_dns_cache_miss` signal"""
host = attr.ib(type=str)
class Trace:
""" Internal class used to keep together the main dependencies used
at the moment of send a signal."""
def __init__(self,
session: 'ClientSession',
trace_config: TraceConfig,
trace_config_ctx: SimpleNamespace) -> None:
self._trace_config = trace_config
self._trace_config_ctx = trace_config_ctx
self._session = session
async def send_request_start(self,
method: str,
url: URL,
headers: 'CIMultiDict[str]') -> None:
return await self._trace_config.on_request_start.send(
self._session,
self._trace_config_ctx,
TraceRequestStartParams(method, url, headers)
)
async def send_request_chunk_sent(self, chunk: bytes) -> None:
return await self._trace_config.on_request_chunk_sent.send(
self._session,
self._trace_config_ctx,
TraceRequestChunkSentParams(chunk)
)
async def send_response_chunk_received(self, chunk: bytes) -> None:
return await self._trace_config.on_response_chunk_received.send(
self._session,
self._trace_config_ctx,
TraceResponseChunkReceivedParams(chunk)
)
async def send_request_end(self,
method: str,
url: URL,
headers: 'CIMultiDict[str]',
response: ClientResponse) -> None:
return await self._trace_config.on_request_end.send(
self._session,
self._trace_config_ctx,
TraceRequestEndParams(method, url, headers, response)
)
async def send_request_exception(self,
method: str,
url: URL,
headers: 'CIMultiDict[str]',
exception: BaseException) -> None:
return await self._trace_config.on_request_exception.send(
self._session,
self._trace_config_ctx,
TraceRequestExceptionParams(method, url, headers, exception)
)
async def send_request_redirect(self,
method: str,
url: URL,
headers: 'CIMultiDict[str]',
response: ClientResponse) -> None:
        return await self._trace_config.on_request_redirect.send(
self._session,
self._trace_config_ctx,
TraceRequestRedirectParams(method, url, headers, response)
)
async def send_connection_queued_start(self) -> None:
return await self._trace_config.on_connection_queued_start.send(
self._session,
self._trace_config_ctx,
TraceConnectionQueuedStartParams()
)
async def send_connection_queued_end(self) -> None:
return await self._trace_config.on_connection_queued_end.send(
self._session,
self._trace_config_ctx,
TraceConnectionQueuedEndParams()
)
async def send_connection_create_start(self) -> None:
return await self._trace_config.on_connection_create_start.send(
self._session,
self._trace_config_ctx,
TraceConnectionCreateStartParams()
)
async def send_connection_create_end(self) -> None:
return await self._trace_config.on_connection_create_end.send(
self._session,
self._trace_config_ctx,
TraceConnectionCreateEndParams()
)
async def send_connection_reuseconn(self) -> None:
return await self._trace_config.on_connection_reuseconn.send(
self._session,
self._trace_config_ctx,
TraceConnectionReuseconnParams()
)
async def send_dns_resolvehost_start(self, host: str) -> None:
return await self._trace_config.on_dns_resolvehost_start.send(
self._session,
self._trace_config_ctx,
TraceDnsResolveHostStartParams(host)
)
async def send_dns_resolvehost_end(self, host: str) -> None:
return await self._trace_config.on_dns_resolvehost_end.send(
self._session,
self._trace_config_ctx,
TraceDnsResolveHostEndParams(host)
)
async def send_dns_cache_hit(self, host: str) -> None:
return await self._trace_config.on_dns_cache_hit.send(
self._session,
self._trace_config_ctx,
TraceDnsCacheHitParams(host)
)
async def send_dns_cache_miss(self, host: str) -> None:
return await self._trace_config.on_dns_cache_miss.send(
self._session,
self._trace_config_ctx,
TraceDnsCacheMissParams(host)
)

@@ -1,53 +0,0 @@
import json
import os # noqa
import pathlib # noqa
import sys
from typing import (
TYPE_CHECKING,
Any,
Callable,
Iterable,
Mapping,
Tuple,
Union,
)
from multidict import (
CIMultiDict,
CIMultiDictProxy,
MultiDict,
MultiDictProxy,
istr,
)
from yarl import URL
DEFAULT_JSON_ENCODER = json.dumps
DEFAULT_JSON_DECODER = json.loads
if TYPE_CHECKING: # pragma: no cover
_CIMultiDict = CIMultiDict[str]
_CIMultiDictProxy = CIMultiDictProxy[str]
_MultiDict = MultiDict[str]
_MultiDictProxy = MultiDictProxy[str]
from http.cookies import BaseCookie # noqa
else:
_CIMultiDict = CIMultiDict
_CIMultiDictProxy = CIMultiDictProxy
_MultiDict = MultiDict
_MultiDictProxy = MultiDictProxy
Byteish = Union[bytes, bytearray, memoryview]
JSONEncoder = Callable[[Any], str]
JSONDecoder = Callable[[str], Any]
LooseHeaders = Union[Mapping[Union[str, istr], str], _CIMultiDict,
_CIMultiDictProxy]
RawHeaders = Tuple[Tuple[bytes, bytes], ...]
StrOrURL = Union[str, URL]
LooseCookies = Union[Iterable[Tuple[str, 'BaseCookie[str]']],
Mapping[str, 'BaseCookie[str]'], 'BaseCookie[str]']
if sys.version_info >= (3, 6):
PathLike = Union[str, 'os.PathLike[str]']
else:
PathLike = Union[str, pathlib.PurePath]

@@ -1,446 +1,333 @@
import asyncio
import logging
import socket
import sys
import warnings
from argparse import ArgumentParser
from collections.abc import Iterable
from importlib import import_module
from typing import Any, Awaitable, Callable, List, Optional, Type, Union, cast
from .abc import AbstractAccessLogger
from .helpers import all_tasks
from .log import access_logger
from .web_app import Application as Application
from .web_app import CleanupError as CleanupError
from .web_exceptions import HTTPAccepted as HTTPAccepted
from .web_exceptions import HTTPBadGateway as HTTPBadGateway
from .web_exceptions import HTTPBadRequest as HTTPBadRequest
from .web_exceptions import HTTPClientError as HTTPClientError
from .web_exceptions import HTTPConflict as HTTPConflict
from .web_exceptions import HTTPCreated as HTTPCreated
from .web_exceptions import HTTPError as HTTPError
from .web_exceptions import HTTPException as HTTPException
from .web_exceptions import HTTPExpectationFailed as HTTPExpectationFailed
from .web_exceptions import HTTPFailedDependency as HTTPFailedDependency
from .web_exceptions import HTTPForbidden as HTTPForbidden
from .web_exceptions import HTTPFound as HTTPFound
from .web_exceptions import HTTPGatewayTimeout as HTTPGatewayTimeout
from .web_exceptions import HTTPGone as HTTPGone
from .web_exceptions import HTTPInsufficientStorage as HTTPInsufficientStorage
from .web_exceptions import HTTPInternalServerError as HTTPInternalServerError
from .web_exceptions import HTTPLengthRequired as HTTPLengthRequired
from .web_exceptions import HTTPMethodNotAllowed as HTTPMethodNotAllowed
from .web_exceptions import HTTPMisdirectedRequest as HTTPMisdirectedRequest
from .web_exceptions import HTTPMovedPermanently as HTTPMovedPermanently
from .web_exceptions import HTTPMultipleChoices as HTTPMultipleChoices
from .web_exceptions import (
HTTPNetworkAuthenticationRequired as HTTPNetworkAuthenticationRequired,
)
from .web_exceptions import HTTPNoContent as HTTPNoContent
from .web_exceptions import (
HTTPNonAuthoritativeInformation as HTTPNonAuthoritativeInformation,
)
from .web_exceptions import HTTPNotAcceptable as HTTPNotAcceptable
from .web_exceptions import HTTPNotExtended as HTTPNotExtended
from .web_exceptions import HTTPNotFound as HTTPNotFound
from .web_exceptions import HTTPNotImplemented as HTTPNotImplemented
from .web_exceptions import HTTPNotModified as HTTPNotModified
from .web_exceptions import HTTPOk as HTTPOk
from .web_exceptions import HTTPPartialContent as HTTPPartialContent
from .web_exceptions import HTTPPaymentRequired as HTTPPaymentRequired
from .web_exceptions import HTTPPermanentRedirect as HTTPPermanentRedirect
from .web_exceptions import HTTPPreconditionFailed as HTTPPreconditionFailed
from .web_exceptions import (
HTTPPreconditionRequired as HTTPPreconditionRequired,
)
from .web_exceptions import (
HTTPProxyAuthenticationRequired as HTTPProxyAuthenticationRequired,
)
from .web_exceptions import HTTPRedirection as HTTPRedirection
from .web_exceptions import (
HTTPRequestEntityTooLarge as HTTPRequestEntityTooLarge,
)
from .web_exceptions import (
HTTPRequestHeaderFieldsTooLarge as HTTPRequestHeaderFieldsTooLarge,
)
from .web_exceptions import (
HTTPRequestRangeNotSatisfiable as HTTPRequestRangeNotSatisfiable,
)
from .web_exceptions import HTTPRequestTimeout as HTTPRequestTimeout
from .web_exceptions import HTTPRequestURITooLong as HTTPRequestURITooLong
from .web_exceptions import HTTPResetContent as HTTPResetContent
from .web_exceptions import HTTPSeeOther as HTTPSeeOther
from .web_exceptions import HTTPServerError as HTTPServerError
from .web_exceptions import HTTPServiceUnavailable as HTTPServiceUnavailable
from .web_exceptions import HTTPSuccessful as HTTPSuccessful
from .web_exceptions import HTTPTemporaryRedirect as HTTPTemporaryRedirect
from .web_exceptions import HTTPTooManyRequests as HTTPTooManyRequests
from .web_exceptions import HTTPUnauthorized as HTTPUnauthorized
from .web_exceptions import (
HTTPUnavailableForLegalReasons as HTTPUnavailableForLegalReasons,
)
from .web_exceptions import HTTPUnprocessableEntity as HTTPUnprocessableEntity
from .web_exceptions import (
HTTPUnsupportedMediaType as HTTPUnsupportedMediaType,
)
from .web_exceptions import HTTPUpgradeRequired as HTTPUpgradeRequired
from .web_exceptions import HTTPUseProxy as HTTPUseProxy
from .web_exceptions import (
HTTPVariantAlsoNegotiates as HTTPVariantAlsoNegotiates,
)
from .web_exceptions import HTTPVersionNotSupported as HTTPVersionNotSupported
from .web_fileresponse import FileResponse as FileResponse
from .web_log import AccessLogger
from .web_middlewares import middleware as middleware
from .web_middlewares import (
normalize_path_middleware as normalize_path_middleware,
)
from .web_protocol import PayloadAccessError as PayloadAccessError
from .web_protocol import RequestHandler as RequestHandler
from .web_protocol import RequestPayloadError as RequestPayloadError
from .web_request import BaseRequest as BaseRequest
from .web_request import FileField as FileField
from .web_request import Request as Request
from .web_response import ContentCoding as ContentCoding
from .web_response import Response as Response
from .web_response import StreamResponse as StreamResponse
from .web_response import json_response as json_response
from .web_routedef import AbstractRouteDef as AbstractRouteDef
from .web_routedef import RouteDef as RouteDef
from .web_routedef import RouteTableDef as RouteTableDef
from .web_routedef import StaticDef as StaticDef
from .web_routedef import delete as delete
from .web_routedef import get as get
from .web_routedef import head as head
from .web_routedef import options as options
from .web_routedef import patch as patch
from .web_routedef import post as post
from .web_routedef import put as put
from .web_routedef import route as route
from .web_routedef import static as static
from .web_routedef import view as view
from .web_runner import AppRunner as AppRunner
from .web_runner import BaseRunner as BaseRunner
from .web_runner import BaseSite as BaseSite
from .web_runner import GracefulExit as GracefulExit
from .web_runner import NamedPipeSite as NamedPipeSite
from .web_runner import ServerRunner as ServerRunner
from .web_runner import SockSite as SockSite
from .web_runner import TCPSite as TCPSite
from .web_runner import UnixSite as UnixSite
from .web_server import Server as Server
from .web_urldispatcher import AbstractResource as AbstractResource
from .web_urldispatcher import AbstractRoute as AbstractRoute
from .web_urldispatcher import DynamicResource as DynamicResource
from .web_urldispatcher import PlainResource as PlainResource
from .web_urldispatcher import Resource as Resource
from .web_urldispatcher import ResourceRoute as ResourceRoute
from .web_urldispatcher import StaticResource as StaticResource
from .web_urldispatcher import UrlDispatcher as UrlDispatcher
from .web_urldispatcher import UrlMappingMatchInfo as UrlMappingMatchInfo
from .web_urldispatcher import View as View
from .web_ws import WebSocketReady as WebSocketReady
from .web_ws import WebSocketResponse as WebSocketResponse
from .web_ws import WSMsgType as WSMsgType
__all__ = (
# web_app
'Application',
'CleanupError',
# web_exceptions
'HTTPAccepted',
'HTTPBadGateway',
'HTTPBadRequest',
'HTTPClientError',
'HTTPConflict',
'HTTPCreated',
'HTTPError',
'HTTPException',
'HTTPExpectationFailed',
'HTTPFailedDependency',
'HTTPForbidden',
'HTTPFound',
'HTTPGatewayTimeout',
'HTTPGone',
'HTTPInsufficientStorage',
'HTTPInternalServerError',
'HTTPLengthRequired',
'HTTPMethodNotAllowed',
'HTTPMisdirectedRequest',
'HTTPMovedPermanently',
'HTTPMultipleChoices',
'HTTPNetworkAuthenticationRequired',
'HTTPNoContent',
'HTTPNonAuthoritativeInformation',
'HTTPNotAcceptable',
'HTTPNotExtended',
'HTTPNotFound',
'HTTPNotImplemented',
'HTTPNotModified',
'HTTPOk',
'HTTPPartialContent',
'HTTPPaymentRequired',
'HTTPPermanentRedirect',
'HTTPPreconditionFailed',
'HTTPPreconditionRequired',
'HTTPProxyAuthenticationRequired',
'HTTPRedirection',
'HTTPRequestEntityTooLarge',
'HTTPRequestHeaderFieldsTooLarge',
'HTTPRequestRangeNotSatisfiable',
'HTTPRequestTimeout',
'HTTPRequestURITooLong',
'HTTPResetContent',
'HTTPSeeOther',
'HTTPServerError',
'HTTPServiceUnavailable',
'HTTPSuccessful',
'HTTPTemporaryRedirect',
'HTTPTooManyRequests',
'HTTPUnauthorized',
'HTTPUnavailableForLegalReasons',
'HTTPUnprocessableEntity',
'HTTPUnsupportedMediaType',
'HTTPUpgradeRequired',
'HTTPUseProxy',
'HTTPVariantAlsoNegotiates',
'HTTPVersionNotSupported',
# web_fileresponse
'FileResponse',
# web_middlewares
'middleware',
'normalize_path_middleware',
# web_protocol
'PayloadAccessError',
'RequestHandler',
'RequestPayloadError',
# web_request
'BaseRequest',
'FileField',
'Request',
# web_response
'ContentCoding',
'Response',
'StreamResponse',
'json_response',
# web_routedef
'AbstractRouteDef',
'RouteDef',
'RouteTableDef',
'StaticDef',
'delete',
'get',
'head',
'options',
'patch',
'post',
'put',
'route',
'static',
'view',
# web_runner
'AppRunner',
'BaseRunner',
'BaseSite',
'GracefulExit',
'ServerRunner',
'SockSite',
'TCPSite',
'UnixSite',
'NamedPipeSite',
# web_server
'Server',
# web_urldispatcher
'AbstractResource',
'AbstractRoute',
'DynamicResource',
'PlainResource',
'Resource',
'ResourceRoute',
'StaticResource',
'UrlDispatcher',
'UrlMappingMatchInfo',
'View',
# web_ws
'WebSocketReady',
'WebSocketResponse',
'WSMsgType',
# web
'run_app',
)
try:
from ssl import SSLContext
except ImportError: # pragma: no cover
SSLContext = Any # type: ignore
async def _run_app(app: Union[Application, Awaitable[Application]], *,
host: Optional[str]=None,
port: Optional[int]=None,
path: Optional[str]=None,
sock: Optional[socket.socket]=None,
shutdown_timeout: float=60.0,
ssl_context: Optional[SSLContext]=None,
print: Callable[..., None]=print,
backlog: int=128,
access_log_class: Type[AbstractAccessLogger]=AccessLogger,
access_log_format: str=AccessLogger.LOG_FORMAT,
access_log: Optional[logging.Logger]=access_logger,
handle_signals: bool=True,
reuse_address: Optional[bool]=None,
reuse_port: Optional[bool]=None) -> None:
    # An internal function that does all the dirty work of running an application
if asyncio.iscoroutine(app):
app = await app # type: ignore
app = cast(Application, app)
runner = AppRunner(app, handle_signals=handle_signals,
access_log_class=access_log_class,
access_log_format=access_log_format,
access_log=access_log)
await runner.setup()
sites = [] # type: List[BaseSite]
try:
if host is not None:
if isinstance(host, (str, bytes, bytearray, memoryview)):
sites.append(TCPSite(runner, host, port,
shutdown_timeout=shutdown_timeout,
ssl_context=ssl_context,
backlog=backlog,
reuse_address=reuse_address,
reuse_port=reuse_port))
else:
for h in host:
sites.append(TCPSite(runner, h, port,
shutdown_timeout=shutdown_timeout,
ssl_context=ssl_context,
backlog=backlog,
reuse_address=reuse_address,
reuse_port=reuse_port))
        elif (path is None and sock is None) or port is not None:
sites.append(TCPSite(runner, port=port,
shutdown_timeout=shutdown_timeout,
ssl_context=ssl_context, backlog=backlog,
reuse_address=reuse_address,
reuse_port=reuse_port))
if path is not None:
if isinstance(path, (str, bytes, bytearray, memoryview)):
sites.append(UnixSite(runner, path,
shutdown_timeout=shutdown_timeout,
ssl_context=ssl_context,
backlog=backlog))
else:
for p in path:
sites.append(UnixSite(runner, p,
shutdown_timeout=shutdown_timeout,
ssl_context=ssl_context,
backlog=backlog))
if sock is not None:
if not isinstance(sock, Iterable):
sites.append(SockSite(runner, sock,
shutdown_timeout=shutdown_timeout,
ssl_context=ssl_context,
backlog=backlog))
else:
for s in sock:
sites.append(SockSite(runner, s,
shutdown_timeout=shutdown_timeout,
ssl_context=ssl_context,
backlog=backlog))
for site in sites:
await site.start()
if print: # pragma: no branch
names = sorted(str(s.name) for s in runner.sites)
print("======== Running on {} ========\n"
"(Press CTRL+C to quit)".format(', '.join(names)))
while True:
                await asyncio.sleep(3600)  # sleep forever, in 1-hour intervals
finally:
await runner.cleanup()
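
# Editor's sketch (not part of the original file): _run_app above is roughly
# an AppRunner plus one or more sites driven by hand, as below. Uncalled.
async def _example_manual_runner(app: Application) -> None:  # pragma: no cover
    runner = AppRunner(app)
    await runner.setup()
    site = TCPSite(runner, '127.0.0.1', 8080)
    await site.start()
    try:
        await asyncio.sleep(3600)  # serve for a while
    finally:
        await runner.cleanup()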
def _cancel_all_tasks(loop: asyncio.AbstractEventLoop) -> None:
to_cancel = all_tasks(loop)
if not to_cancel:
return
for task in to_cancel:
task.cancel()
loop.run_until_complete(
asyncio.gather(*to_cancel, loop=loop, return_exceptions=True))
for task in to_cancel:
if task.cancelled():
continue
if task.exception() is not None:
loop.call_exception_handler({
'message': 'unhandled exception during asyncio.run() shutdown',
'exception': task.exception(),
'task': task,
})
def run_app(app: Union[Application, Awaitable[Application]], *,
host: Optional[str]=None,
port: Optional[int]=None,
path: Optional[str]=None,
sock: Optional[socket.socket]=None,
shutdown_timeout: float=60.0,
ssl_context: Optional[SSLContext]=None,
print: Callable[..., None]=print,
backlog: int=128,
access_log_class: Type[AbstractAccessLogger]=AccessLogger,
access_log_format: str=AccessLogger.LOG_FORMAT,
access_log: Optional[logging.Logger]=access_logger,
handle_signals: bool=True,
reuse_address: Optional[bool]=None,
            reuse_port: Optional[bool]=None) -> None:
    loop = asyncio.get_event_loop()

    # Configure if and only if in debugging mode and using the default logger
    if loop.get_debug() and access_log and access_log.name == 'aiohttp.access':
        if access_log.level == logging.NOTSET:
            access_log.setLevel(logging.DEBUG)
        if not access_log.hasHandlers():
            access_log.addHandler(logging.StreamHandler())

    try:
        loop.run_until_complete(_run_app(app,
                                         host=host,
                                         port=port,
                                         path=path,
                                         sock=sock,
                                         shutdown_timeout=shutdown_timeout,
                                         ssl_context=ssl_context,
                                         print=print,
                                         backlog=backlog,
                                         access_log_class=access_log_class,
                                         access_log_format=access_log_format,
                                         access_log=access_log,
                                         handle_signals=handle_signals,
                                         reuse_address=reuse_address,
                                         reuse_port=reuse_port))
    except (GracefulExit, KeyboardInterrupt):  # pragma: no cover
        pass
    finally:
        _cancel_all_tasks(loop)
        if sys.version_info >= (3, 6):  # don't use PY_36 to pass mypy
            loop.run_until_complete(loop.shutdown_asyncgens())
        loop.close()
from . import hdrs, web_exceptions, web_reqrep, web_urldispatcher, web_ws
from .abc import AbstractMatchInfo, AbstractRouter
from .helpers import sentinel
from .log import web_logger
from .protocol import HttpVersion # noqa
from .server import ServerHttpProtocol
from .signals import PostSignal, PreSignal, Signal
from .web_exceptions import * # noqa
from .web_reqrep import * # noqa
from .web_urldispatcher import * # noqa
from .web_ws import * # noqa
__all__ = (web_reqrep.__all__ +
web_exceptions.__all__ +
web_urldispatcher.__all__ +
web_ws.__all__ +
('Application', 'RequestHandler',
'RequestHandlerFactory', 'HttpVersion',
'MsgType'))
class RequestHandler(ServerHttpProtocol):
_meth = 'none'
_path = 'none'
def __init__(self, manager, app, router, *,
secure_proxy_ssl_header=None, **kwargs):
super().__init__(**kwargs)
self._manager = manager
self._app = app
self._router = router
self._middlewares = app.middlewares
self._secure_proxy_ssl_header = secure_proxy_ssl_header
def __repr__(self):
return "<{} {}:{} {}>".format(
self.__class__.__name__, self._meth, self._path,
'connected' if self.transport is not None else 'disconnected')
def connection_made(self, transport):
super().connection_made(transport)
self._manager.connection_made(self, transport)
def connection_lost(self, exc):
self._manager.connection_lost(self, exc)
super().connection_lost(exc)
@asyncio.coroutine
def handle_request(self, message, payload):
self._manager._requests_count += 1
if self.access_log:
now = self._loop.time()
app = self._app
request = web_reqrep.Request(
app, message, payload,
self.transport, self.reader, self.writer,
secure_proxy_ssl_header=self._secure_proxy_ssl_header)
self._meth = request.method
self._path = request.path
try:
match_info = yield from self._router.resolve(request)
assert isinstance(match_info, AbstractMatchInfo), match_info
resp = None
request._match_info = match_info
expect = request.headers.get(hdrs.EXPECT)
if expect:
resp = (
yield from match_info.expect_handler(request))
if resp is None:
handler = match_info.handler
for factory in reversed(self._middlewares):
handler = yield from factory(app, handler)
resp = yield from handler(request)
assert isinstance(resp, web_reqrep.StreamResponse), \
("Handler {!r} should return response instance, "
"got {!r} [middlewares {!r}]").format(
match_info.handler, type(resp), self._middlewares)
except web_exceptions.HTTPException as exc:
resp = exc
resp_msg = yield from resp.prepare(request)
yield from resp.write_eof()
# notify server about keep-alive
self.keep_alive(resp.keep_alive)
# log access
if self.access_log:
self.log_access(message, None, resp_msg, self._loop.time() - now)
# for repr
self._meth = 'none'
self._path = 'none'
class RequestHandlerFactory:
def __init__(self, app, router, *,
handler=RequestHandler, loop=None,
secure_proxy_ssl_header=None, **kwargs):
self._app = app
self._router = router
self._handler = handler
self._loop = loop
self._connections = {}
self._secure_proxy_ssl_header = secure_proxy_ssl_header
self._kwargs = kwargs
self._kwargs.setdefault('logger', app.logger)
self._requests_count = 0
@property
def requests_count(self):
"""Number of processed requests."""
return self._requests_count
@property
def secure_proxy_ssl_header(self):
return self._secure_proxy_ssl_header
@property
def connections(self):
return list(self._connections.keys())
def connection_made(self, handler, transport):
self._connections[handler] = transport
def connection_lost(self, handler, exc=None):
if handler in self._connections:
del self._connections[handler]
@asyncio.coroutine
def finish_connections(self, timeout=None):
coros = [conn.shutdown(timeout) for conn in self._connections]
yield from asyncio.gather(*coros, loop=self._loop)
self._connections.clear()
def __call__(self):
return self._handler(
self, self._app, self._router, loop=self._loop,
secure_proxy_ssl_header=self._secure_proxy_ssl_header,
**self._kwargs)
class Application(dict):
def __init__(self, *, logger=web_logger, loop=None,
router=None, handler_factory=RequestHandlerFactory,
middlewares=(), debug=False):
if loop is None:
loop = asyncio.get_event_loop()
if router is None:
router = web_urldispatcher.UrlDispatcher()
assert isinstance(router, AbstractRouter), router
self._debug = debug
self._router = router
self._handler_factory = handler_factory
self._loop = loop
self.logger = logger
self._middlewares = list(middlewares)
self._on_pre_signal = PreSignal()
self._on_post_signal = PostSignal()
self._on_response_prepare = Signal(self)
self._on_startup = Signal(self)
self._on_shutdown = Signal(self)
self._on_cleanup = Signal(self)
@property
def debug(self):
return self._debug
@property
def on_response_prepare(self):
return self._on_response_prepare
@property
def on_pre_signal(self):
return self._on_pre_signal
@property
def on_post_signal(self):
return self._on_post_signal
@property
def on_startup(self):
return self._on_startup
@property
def on_shutdown(self):
return self._on_shutdown
@property
def on_cleanup(self):
return self._on_cleanup
@property
def router(self):
return self._router
@property
def loop(self):
return self._loop
@property
def middlewares(self):
return self._middlewares
def make_handler(self, **kwargs):
debug = kwargs.pop('debug', sentinel)
if debug is not sentinel:
warnings.warn(
"`debug` parameter is deprecated. "
"Use Application's debug mode instead", DeprecationWarning)
if debug != self.debug:
raise ValueError(
"The value of `debug` parameter conflicts with the debug "
"settings of the `Application` instance. The "
"application's debug mode setting should be used instead "
"as a single point to setup a debug mode. For more "
"information please check "
"http://aiohttp.readthedocs.io/en/stable/"
"web_reference.html#aiohttp.web.Application"
)
return self._handler_factory(self, self.router, debug=self.debug,
loop=self.loop, **kwargs)
@asyncio.coroutine
def startup(self):
"""Causes on_startup signal
Should be called in the event loop along with the request handler.
"""
yield from self.on_startup.send(self)
@asyncio.coroutine
def shutdown(self):
"""Causes on_shutdown signal
Should be called before cleanup()
"""
yield from self.on_shutdown.send(self)
@asyncio.coroutine
def cleanup(self):
"""Causes on_cleanup signal
Should be called after shutdown()
"""
yield from self.on_cleanup.send(self)
@asyncio.coroutine
def finish(self):
"""Finalize an application.
Deprecated alias for .cleanup()
"""
warnings.warn("Use .cleanup() instead", DeprecationWarning)
yield from self.cleanup()
def register_on_finish(self, func, *args, **kwargs):
warnings.warn("Use .on_cleanup.append() instead", DeprecationWarning)
self.on_cleanup.append(lambda app: func(app, *args, **kwargs))
def copy(self):
raise NotImplementedError
def __call__(self):
"""gunicorn compatibility"""
return self
def __repr__(self):
return "<Application>"
def run_app(app, *, host='0.0.0.0', port=None,
            shutdown_timeout=60.0, ssl_context=None,
            print=print, backlog=128):
    """Run an app locally"""
    loop = app.loop
    if port is None:
        if not ssl_context:
            port = 8080
        else:
            port = 8443
    handler = app.make_handler()
    server = loop.create_server(handler, host, port, ssl=ssl_context,
                                backlog=backlog)
    srv, startup_res = loop.run_until_complete(asyncio.gather(server,
                                                              app.startup(),
                                                              loop=loop))
    scheme = 'https' if ssl_context else 'http'
    print("======== Running on {scheme}://{host}:{port}/ ========\n"
          "(Press CTRL+C to quit)".format(
              scheme=scheme, host=host, port=port))
    try:
        loop.run_forever()
    except KeyboardInterrupt:  # pragma: no cover
        pass
    finally:
        srv.close()
        loop.run_until_complete(srv.wait_closed())
        loop.run_until_complete(app.shutdown())
        loop.run_until_complete(handler.finish_connections(shutdown_timeout))
        loop.run_until_complete(app.cleanup())
    loop.close()
def main(argv: List[str]) -> None:
arg_parser = ArgumentParser(
description="aiohttp.web Application server",
prog="aiohttp.web"
@@ -462,11 +349,6 @@ def main(argv: List[str]) -> None:
type=int,
default="8080"
)
arg_parser.add_argument(
"-U", "--path",
help="Unix file system path to serve on. Specifying a path will cause "
"hostname and port arguments to be ignored.",
)
args, extra_argv = arg_parser.parse_known_args(argv)
# Import logic
@@ -479,24 +361,16 @@ def main(argv: List[str]) -> None:
arg_parser.error("relative module names not supported")
try:
module = import_module(mod_str)
except ImportError as ex:
arg_parser.error("unable to import %s: %s" % (mod_str, ex))
try:
func = getattr(module, func_str)
except AttributeError:
arg_parser.error("module %r has no attribute %r" % (mod_str, func_str))
# Compatibility logic
if args.path is not None and not hasattr(socket, 'AF_UNIX'):
arg_parser.error("file system paths not supported by your operating"
" environment")
logging.basicConfig(level=logging.DEBUG)
app = func(extra_argv)
run_app(app, host=args.hostname, port=args.port, path=args.path)
arg_parser.exit(message="Stopped\n")
if __name__ == "__main__": # pragma: no branch
main(sys.argv[1:]) # pragma: no cover
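
# Editor's usage note (illustrative): main() backs the command line entry
# point, roughly:
#
#   python -m aiohttp.web -H 127.0.0.1 -P 8080 mypackage.mymodule:init_app
#
# where init_app(argv) must return an Application instance
# ('mypackage.mymodule' is a made-up name).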

@@ -1,514 +0,0 @@
import asyncio
import logging
import warnings
from functools import partial
from typing import ( # noqa
TYPE_CHECKING,
Any,
AsyncIterator,
Awaitable,
Callable,
Dict,
Iterable,
Iterator,
List,
Mapping,
MutableMapping,
Optional,
Sequence,
Tuple,
Type,
Union,
cast,
)
from . import hdrs
from .abc import (
AbstractAccessLogger,
AbstractMatchInfo,
AbstractRouter,
AbstractStreamWriter,
)
from .frozenlist import FrozenList
from .helpers import DEBUG
from .http_parser import RawRequestMessage
from .log import web_logger
from .signals import Signal
from .streams import StreamReader
from .web_log import AccessLogger
from .web_middlewares import _fix_request_current_app
from .web_protocol import RequestHandler
from .web_request import Request
from .web_response import StreamResponse
from .web_routedef import AbstractRouteDef
from .web_server import Server
from .web_urldispatcher import (
AbstractResource,
Domain,
MaskDomain,
MatchedSubAppResource,
PrefixedSubAppResource,
UrlDispatcher,
)
__all__ = ('Application', 'CleanupError')
if TYPE_CHECKING: # pragma: no cover
_AppSignal = Signal[Callable[['Application'], Awaitable[None]]]
_RespPrepareSignal = Signal[Callable[[Request, StreamResponse],
Awaitable[None]]]
_Handler = Callable[[Request], Awaitable[StreamResponse]]
_Middleware = Union[Callable[[Request, _Handler],
Awaitable[StreamResponse]],
Callable[['Application', _Handler], # old-style
Awaitable[_Handler]]]
_Middlewares = FrozenList[_Middleware]
_MiddlewaresHandlers = Optional[Sequence[Tuple[_Middleware, bool]]]
_Subapps = List['Application']
else:
# No type checker mode, skip types
_AppSignal = Signal
_RespPrepareSignal = Signal
_Handler = Callable
_Middleware = Callable
_Middlewares = FrozenList
_MiddlewaresHandlers = Optional[Sequence]
_Subapps = List
class Application(MutableMapping[str, Any]):
ATTRS = frozenset([
'logger', '_debug', '_router', '_loop', '_handler_args',
'_middlewares', '_middlewares_handlers', '_run_middlewares',
'_state', '_frozen', '_pre_frozen', '_subapps',
'_on_response_prepare', '_on_startup', '_on_shutdown',
'_on_cleanup', '_client_max_size', '_cleanup_ctx'])
def __init__(self, *,
logger: logging.Logger=web_logger,
router: Optional[UrlDispatcher]=None,
middlewares: Iterable[_Middleware]=(),
handler_args: Mapping[str, Any]=None,
client_max_size: int=1024**2,
loop: Optional[asyncio.AbstractEventLoop]=None,
debug: Any=... # mypy doesn't support ellipsis
) -> None:
if router is None:
router = UrlDispatcher()
else:
warnings.warn("router argument is deprecated", DeprecationWarning,
stacklevel=2)
assert isinstance(router, AbstractRouter), router
if loop is not None:
warnings.warn("loop argument is deprecated", DeprecationWarning,
stacklevel=2)
if debug is not ...:
warnings.warn("debug argument is deprecated",
DeprecationWarning,
stacklevel=2)
self._debug = debug
self._router = router # type: UrlDispatcher
self._loop = loop
self._handler_args = handler_args
self.logger = logger
self._middlewares = FrozenList(middlewares) # type: _Middlewares
# initialized on freezing
self._middlewares_handlers = None # type: _MiddlewaresHandlers
# initialized on freezing
self._run_middlewares = None # type: Optional[bool]
self._state = {} # type: Dict[str, Any]
self._frozen = False
self._pre_frozen = False
self._subapps = [] # type: _Subapps
self._on_response_prepare = Signal(self) # type: _RespPrepareSignal
self._on_startup = Signal(self) # type: _AppSignal
self._on_shutdown = Signal(self) # type: _AppSignal
self._on_cleanup = Signal(self) # type: _AppSignal
self._cleanup_ctx = CleanupContext()
self._on_startup.append(self._cleanup_ctx._on_startup)
self._on_cleanup.append(self._cleanup_ctx._on_cleanup)
self._client_max_size = client_max_size
def __init_subclass__(cls: Type['Application']) -> None:
warnings.warn("Inheritance class {} from web.Application "
"is discouraged".format(cls.__name__),
DeprecationWarning,
stacklevel=2)
if DEBUG: # pragma: no cover
def __setattr__(self, name: str, val: Any) -> None:
if name not in self.ATTRS:
warnings.warn("Setting custom web.Application.{} attribute "
"is discouraged".format(name),
DeprecationWarning,
stacklevel=2)
super().__setattr__(name, val)
# MutableMapping API
def __eq__(self, other: object) -> bool:
return self is other
def __getitem__(self, key: str) -> Any:
return self._state[key]
def _check_frozen(self) -> None:
if self._frozen:
warnings.warn("Changing state of started or joined "
"application is deprecated",
DeprecationWarning,
stacklevel=3)
def __setitem__(self, key: str, value: Any) -> None:
self._check_frozen()
self._state[key] = value
def __delitem__(self, key: str) -> None:
self._check_frozen()
del self._state[key]
def __len__(self) -> int:
return len(self._state)
def __iter__(self) -> Iterator[str]:
return iter(self._state)
########
@property
def loop(self) -> asyncio.AbstractEventLoop:
# Technically the loop can be None
# but we mask it with an explicit type cast
# to provide a more convenient type annotation
warnings.warn("loop property is deprecated",
DeprecationWarning,
stacklevel=2)
return cast(asyncio.AbstractEventLoop, self._loop)
def _set_loop(self, loop: Optional[asyncio.AbstractEventLoop]) -> None:
if loop is None:
loop = asyncio.get_event_loop()
if self._loop is not None and self._loop is not loop:
raise RuntimeError(
"web.Application instance initialized with different loop")
self._loop = loop
# set loop debug
if self._debug is ...:
self._debug = loop.get_debug()
# set loop to sub applications
for subapp in self._subapps:
subapp._set_loop(loop)
@property
def pre_frozen(self) -> bool:
return self._pre_frozen
def pre_freeze(self) -> None:
if self._pre_frozen:
return
self._pre_frozen = True
self._middlewares.freeze()
self._router.freeze()
self._on_response_prepare.freeze()
self._cleanup_ctx.freeze()
self._on_startup.freeze()
self._on_shutdown.freeze()
self._on_cleanup.freeze()
self._middlewares_handlers = tuple(self._prepare_middleware())
# If neither the current app nor any of its subapps has middlewares,
# avoid running all of the code that middleware support implies,
# including the per-app hardcoded middleware that sets up the
# current_app attribute. If no middlewares are configured the handler
# will receive the proper current_app without needing any of this code.
self._run_middlewares = bool(self.middlewares)
for subapp in self._subapps:
subapp.pre_freeze()
self._run_middlewares = (self._run_middlewares or
subapp._run_middlewares)
@property
def frozen(self) -> bool:
return self._frozen
def freeze(self) -> None:
if self._frozen:
return
self.pre_freeze()
self._frozen = True
for subapp in self._subapps:
subapp.freeze()
@property
def debug(self) -> bool:
warnings.warn("debug property is deprecated",
DeprecationWarning,
stacklevel=2)
return self._debug
def _reg_subapp_signals(self, subapp: 'Application') -> None:
def reg_handler(signame: str) -> None:
subsig = getattr(subapp, signame)
async def handler(app: 'Application') -> None:
await subsig.send(subapp)
appsig = getattr(self, signame)
appsig.append(handler)
reg_handler('on_startup')
reg_handler('on_shutdown')
reg_handler('on_cleanup')
def add_subapp(self, prefix: str,
subapp: 'Application') -> AbstractResource:
if not isinstance(prefix, str):
raise TypeError("Prefix must be str")
prefix = prefix.rstrip('/')
if not prefix:
raise ValueError("Prefix cannot be empty")
factory = partial(PrefixedSubAppResource, prefix, subapp)
return self._add_subapp(factory, subapp)
def _add_subapp(self,
resource_factory: Callable[[], AbstractResource],
subapp: 'Application') -> AbstractResource:
if self.frozen:
raise RuntimeError(
"Cannot add sub application to frozen application")
if subapp.frozen:
raise RuntimeError("Cannot add frozen application")
resource = resource_factory()
self.router.register_resource(resource)
self._reg_subapp_signals(subapp)
self._subapps.append(subapp)
subapp.pre_freeze()
if self._loop is not None:
subapp._set_loop(self._loop)
return resource
def add_domain(self, domain: str,
subapp: 'Application') -> AbstractResource:
if not isinstance(domain, str):
raise TypeError("Domain must be str")
elif '*' in domain:
rule = MaskDomain(domain) # type: Domain
else:
rule = Domain(domain)
factory = partial(MatchedSubAppResource, rule, subapp)
return self._add_subapp(factory, subapp)
def add_routes(self, routes: Iterable[AbstractRouteDef]) -> None:
self.router.add_routes(routes)
@property
def on_response_prepare(self) -> _RespPrepareSignal:
return self._on_response_prepare
@property
def on_startup(self) -> _AppSignal:
return self._on_startup
@property
def on_shutdown(self) -> _AppSignal:
return self._on_shutdown
@property
def on_cleanup(self) -> _AppSignal:
return self._on_cleanup
@property
def cleanup_ctx(self) -> 'CleanupContext':
return self._cleanup_ctx
@property
def router(self) -> UrlDispatcher:
return self._router
@property
def middlewares(self) -> _Middlewares:
return self._middlewares
def _make_handler(self, *,
loop: Optional[asyncio.AbstractEventLoop]=None,
access_log_class: Type[
AbstractAccessLogger]=AccessLogger,
**kwargs: Any) -> Server:
if not issubclass(access_log_class, AbstractAccessLogger):
raise TypeError(
'access_log_class must be subclass of '
'aiohttp.abc.AbstractAccessLogger, got {}'.format(
access_log_class))
self._set_loop(loop)
self.freeze()
kwargs['debug'] = self._debug
kwargs['access_log_class'] = access_log_class
if self._handler_args:
for k, v in self._handler_args.items():
kwargs[k] = v
return Server(self._handle, # type: ignore
request_factory=self._make_request,
loop=self._loop, **kwargs)
def make_handler(self, *,
loop: Optional[asyncio.AbstractEventLoop]=None,
access_log_class: Type[
AbstractAccessLogger]=AccessLogger,
**kwargs: Any) -> Server:
warnings.warn("Application.make_handler(...) is deprecated, "
"use AppRunner API instead",
DeprecationWarning,
stacklevel=2)
return self._make_handler(loop=loop,
access_log_class=access_log_class,
**kwargs)
async def startup(self) -> None:
"""Causes on_startup signal
Should be called in the event loop along with the request handler.
"""
await self.on_startup.send(self)
async def shutdown(self) -> None:
"""Causes on_shutdown signal
Should be called before cleanup()
"""
await self.on_shutdown.send(self)
async def cleanup(self) -> None:
"""Causes on_cleanup signal
Should be called after shutdown()
"""
await self.on_cleanup.send(self)
def _make_request(self, message: RawRequestMessage,
payload: StreamReader,
protocol: RequestHandler,
writer: AbstractStreamWriter,
task: 'asyncio.Task[None]',
_cls: Type[Request]=Request) -> Request:
return _cls(
message, payload, protocol, writer, task,
self._loop,
client_max_size=self._client_max_size)
def _prepare_middleware(self) -> Iterator[Tuple[_Middleware, bool]]:
for m in reversed(self._middlewares):
if getattr(m, '__middleware_version__', None) == 1:
yield m, True
else:
warnings.warn('old-style middleware "{!r}" deprecated, '
'see #2252'.format(m),
DeprecationWarning, stacklevel=2)
yield m, False
yield _fix_request_current_app(self), True
async def _handle(self, request: Request) -> StreamResponse:
loop = asyncio.get_event_loop()
debug = loop.get_debug()
match_info = await self._router.resolve(request)
if debug: # pragma: no cover
if not isinstance(match_info, AbstractMatchInfo):
raise TypeError("match_info should be AbstractMatchInfo "
"instance, not {!r}".format(match_info))
match_info.add_app(self)
match_info.freeze()
resp = None
request._match_info = match_info # type: ignore
expect = request.headers.get(hdrs.EXPECT)
if expect:
resp = await match_info.expect_handler(request)
await request.writer.drain()
if resp is None:
handler = match_info.handler
if self._run_middlewares:
for app in match_info.apps[::-1]:
for m, new_style in app._middlewares_handlers: # type: ignore # noqa
if new_style:
handler = partial(m, handler=handler)
else:
handler = await m(app, handler) # type: ignore
resp = await handler(request)
return resp
def __call__(self) -> 'Application':
"""gunicorn compatibility"""
return self
def __repr__(self) -> str:
return "<Application 0x{:x}>".format(id(self))
def __bool__(self) -> bool:
return True
class CleanupError(RuntimeError):
@property
def exceptions(self) -> List[BaseException]:
return self.args[1]
if TYPE_CHECKING: # pragma: no cover
_CleanupContextBase = FrozenList[Callable[[Application],
AsyncIterator[None]]]
else:
_CleanupContextBase = FrozenList
class CleanupContext(_CleanupContextBase):
def __init__(self) -> None:
super().__init__()
self._exits = [] # type: List[AsyncIterator[None]]
async def _on_startup(self, app: Application) -> None:
for cb in self:
it = cb(app).__aiter__()
await it.__anext__()
self._exits.append(it)
async def _on_cleanup(self, app: Application) -> None:
errors = []
for it in reversed(self._exits):
try:
await it.__anext__()
except StopAsyncIteration:
pass
except Exception as exc:
errors.append(exc)
else:
errors.append(RuntimeError("{!r} has more than one 'yield'"
.format(it)))
if errors:
if len(errors) == 1:
raise errors[0]
else:
raise CleanupError("Multiple errors on cleanup stage", errors)
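# A minimal usage sketch (not part of the original module) for the cleanup
# context and sub-application machinery above. An entry appended to
# app.cleanup_ctx must be an async generator that takes the application and
# yields exactly once: code before the yield runs on startup, code after it
# runs on cleanup, and a generator that yields more than once is reported as
# an error during cleanup. The 'db' key and the '/admin' prefix are
# hypothetical.
from aiohttp import web


async def database_ctx(app: web.Application):
    app['db'] = object()  # acquire a hypothetical resource on startup
    yield
    app['db'] = None      # release it on cleanup


def build_app() -> web.Application:
    app = web.Application()
    app.cleanup_ctx.append(database_ctx)
    admin = web.Application()
    # Startup/shutdown/cleanup signals propagate to the subapp.
    app.add_subapp('/admin', admin)
    return app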

@ -1,10 +1,4 @@
import warnings
from typing import Any, Dict, Iterable, List, Optional, Set # noqa
from yarl import URL
from .typedefs import LooseHeaders, StrOrURL
from .web_response import Response
from .web_reqrep import Response
__all__ = (
'HTTPException',
@ -46,8 +40,6 @@ __all__ = (
'HTTPRequestRangeNotSatisfiable',
'HTTPExpectationFailed',
'HTTPMisdirectedRequest',
'HTTPUnprocessableEntity',
'HTTPFailedDependency',
'HTTPUpgradeRequired',
'HTTPPreconditionRequired',
'HTTPTooManyRequests',
@ -61,7 +53,6 @@ __all__ = (
'HTTPGatewayTimeout',
'HTTPVersionNotSupported',
'HTTPVariantAlsoNegotiates',
'HTTPInsufficientStorage',
'HTTPNotExtended',
'HTTPNetworkAuthenticationRequired',
)
@ -76,21 +67,11 @@ class HTTPException(Response, Exception):
# You should set in subclasses:
# status = 200
status_code = -1
status_code = None
empty_body = False
__http_exception__ = True
def __init__(self, *,
headers: Optional[LooseHeaders]=None,
reason: Optional[str]=None,
body: Any=None,
text: Optional[str]=None,
content_type: Optional[str]=None) -> None:
if body is not None:
warnings.warn(
"body argument is deprecated for http web exceptions",
DeprecationWarning)
def __init__(self, *, headers=None, reason=None,
body=None, text=None, content_type=None):
Response.__init__(self, status=self.status_code,
headers=headers, reason=reason,
body=body, text=text, content_type=content_type)
@ -98,9 +79,6 @@ class HTTPException(Response, Exception):
if self.body is None and not self.empty_body:
self.text = "{}: {}".format(self.status, self.reason)
def __bool__(self) -> bool:
return True
class HTTPError(HTTPException):
"""Base class for exceptions with status codes in the 400s and 500s."""
@ -151,19 +129,13 @@ class HTTPPartialContent(HTTPSuccessful):
class _HTTPMove(HTTPRedirection):
def __init__(self,
location: StrOrURL,
*,
headers: Optional[LooseHeaders]=None,
reason: Optional[str]=None,
body: Any=None,
text: Optional[str]=None,
content_type: Optional[str]=None) -> None:
def __init__(self, location, *, headers=None, reason=None,
body=None, text=None, content_type=None):
if not location:
raise ValueError("HTTP redirects need a location to redirect to.")
super().__init__(headers=headers, reason=reason,
body=body, text=text, content_type=content_type)
self.headers['Location'] = str(URL(location))
self.headers['Location'] = location
self.location = location
@ -236,20 +208,13 @@ class HTTPNotFound(HTTPClientError):
class HTTPMethodNotAllowed(HTTPClientError):
status_code = 405
def __init__(self,
method: str,
allowed_methods: Iterable[str],
*,
headers: Optional[LooseHeaders]=None,
reason: Optional[str]=None,
body: Any=None,
text: Optional[str]=None,
content_type: Optional[str]=None) -> None:
def __init__(self, method, allowed_methods, *, headers=None, reason=None,
body=None, text=None, content_type=None):
allow = ','.join(sorted(allowed_methods))
super().__init__(headers=headers, reason=reason,
body=body, text=text, content_type=content_type)
self.headers['Allow'] = allow
self.allowed_methods = set(allowed_methods) # type: Set[str]
self.allowed_methods = allowed_methods
self.method = method.upper()
@ -284,17 +249,6 @@ class HTTPPreconditionFailed(HTTPClientError):
class HTTPRequestEntityTooLarge(HTTPClientError):
status_code = 413
def __init__(self,
max_size: float,
actual_size: float,
**kwargs: Any) -> None:
kwargs.setdefault(
'text',
'Maximum request body size {} exceeded, '
'actual body size {}'.format(max_size, actual_size)
)
super().__init__(**kwargs)
class HTTPRequestURITooLong(HTTPClientError):
status_code = 414
@ -316,14 +270,6 @@ class HTTPMisdirectedRequest(HTTPClientError):
status_code = 421
class HTTPUnprocessableEntity(HTTPClientError):
status_code = 422
class HTTPFailedDependency(HTTPClientError):
status_code = 424
class HTTPUpgradeRequired(HTTPClientError):
status_code = 426
@ -343,14 +289,8 @@ class HTTPRequestHeaderFieldsTooLarge(HTTPClientError):
class HTTPUnavailableForLegalReasons(HTTPClientError):
status_code = 451
def __init__(self,
link: str,
*,
headers: Optional[LooseHeaders]=None,
reason: Optional[str]=None,
body: Any=None,
text: Optional[str]=None,
content_type: Optional[str]=None) -> None:
def __init__(self, link, *, headers=None, reason=None,
body=None, text=None, content_type=None):
super().__init__(headers=headers, reason=reason,
body=body, text=text, content_type=content_type)
self.headers['Link'] = '<%s>; rel="blocked-by"' % link
@ -401,10 +341,6 @@ class HTTPVariantAlsoNegotiates(HTTPServerError):
status_code = 506
class HTTPInsufficientStorage(HTTPServerError):
status_code = 507
class HTTPNotExtended(HTTPServerError):
status_code = 510
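# A hedged usage sketch (not part of the original module). The classes above
# are both responses and exceptions, so a handler can raise them directly;
# the handler name, route parameter and message text are hypothetical.
from aiohttp import web


async def fetch_item(request: web.Request) -> web.Response:
    if request.method not in ('GET', 'HEAD'):
        raise web.HTTPMethodNotAllowed(request.method, ['GET', 'HEAD'])
    item_id = request.match_info.get('id')
    if item_id is None:
        raise web.HTTPNotFound(text='no such item')
    return web.Response(text=item_id)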

@ -1,346 +0,0 @@
import asyncio
import mimetypes
import os
import pathlib
from functools import partial
from typing import ( # noqa
IO,
TYPE_CHECKING,
Any,
Awaitable,
Callable,
List,
Optional,
Union,
cast,
)
from . import hdrs
from .abc import AbstractStreamWriter
from .base_protocol import BaseProtocol
from .helpers import set_exception, set_result
from .http_writer import StreamWriter
from .log import server_logger
from .typedefs import LooseHeaders
from .web_exceptions import (
HTTPNotModified,
HTTPPartialContent,
HTTPPreconditionFailed,
HTTPRequestRangeNotSatisfiable,
)
from .web_response import StreamResponse
__all__ = ('FileResponse',)
if TYPE_CHECKING: # pragma: no cover
from .web_request import BaseRequest # noqa
_T_OnChunkSent = Optional[Callable[[bytes], Awaitable[None]]]
NOSENDFILE = bool(os.environ.get("AIOHTTP_NOSENDFILE"))
class SendfileStreamWriter(StreamWriter):
def __init__(self,
protocol: BaseProtocol,
loop: asyncio.AbstractEventLoop,
fobj: IO[Any],
count: int,
on_chunk_sent: _T_OnChunkSent=None) -> None:
super().__init__(protocol, loop, on_chunk_sent)
self._sendfile_buffer = [] # type: List[bytes]
self._fobj = fobj
self._count = count
self._offset = fobj.tell()
self._in_fd = fobj.fileno()
def _write(self, chunk: bytes) -> None:
# we overwrite StreamWriter._write, so nothing can be appended to
# _buffer, and nothing is written to the transport directly by the
# parent class
self.output_size += len(chunk)
self._sendfile_buffer.append(chunk)
def _sendfile_cb(self, fut: 'asyncio.Future[None]', out_fd: int) -> None:
if fut.cancelled():
return
try:
if self._do_sendfile(out_fd):
set_result(fut, None)
except Exception as exc:
set_exception(fut, exc)
def _do_sendfile(self, out_fd: int) -> bool:
try:
n = os.sendfile(out_fd,
self._in_fd,
self._offset,
self._count)
if n == 0: # in_fd EOF reached
n = self._count
except (BlockingIOError, InterruptedError):
n = 0
self.output_size += n
self._offset += n
self._count -= n
assert self._count >= 0
return self._count == 0
def _done_fut(self, out_fd: int, fut: 'asyncio.Future[None]') -> None:
self.loop.remove_writer(out_fd)
async def sendfile(self) -> None:
assert self.transport is not None
out_socket = self.transport.get_extra_info('socket').dup()
out_socket.setblocking(False)
out_fd = out_socket.fileno()
loop = self.loop
data = b''.join(self._sendfile_buffer)
try:
await loop.sock_sendall(out_socket, data)
if not self._do_sendfile(out_fd):
fut = loop.create_future()
fut.add_done_callback(partial(self._done_fut, out_fd))
loop.add_writer(out_fd, self._sendfile_cb, fut, out_fd)
await fut
except asyncio.CancelledError:
raise
except Exception:
server_logger.debug('Socket error')
self.transport.close()
finally:
out_socket.close()
await super().write_eof()
async def write_eof(self, chunk: bytes=b'') -> None:
pass
class FileResponse(StreamResponse):
"""A response object can be used to send files."""
def __init__(self, path: Union[str, pathlib.Path],
chunk_size: int=256*1024,
status: int=200,
reason: Optional[str]=None,
headers: Optional[LooseHeaders]=None) -> None:
super().__init__(status=status, reason=reason, headers=headers)
if isinstance(path, str):
path = pathlib.Path(path)
self._path = path
self._chunk_size = chunk_size
async def _sendfile_system(self, request: 'BaseRequest',
fobj: IO[Any],
count: int) -> AbstractStreamWriter:
# Write count bytes of fobj to resp using
# the os.sendfile system call.
#
# For details see
# https://github.com/KeepSafe/aiohttp/issues/1177 and
# https://github.com/KeepSafe/aiohttp/issues/958
#
# request should be an aiohttp.web.Request instance.
# fobj should be an open file object.
# count should be an integer > 0.
transport = request.transport
assert transport is not None
if (transport.get_extra_info("sslcontext") or
transport.get_extra_info("socket") is None or
self.compression):
writer = await self._sendfile_fallback(request, fobj, count)
else:
writer = SendfileStreamWriter(
request.protocol,
request._loop,
fobj,
count
)
request._payload_writer = writer
await super().prepare(request)
await writer.sendfile()
return writer
async def _sendfile_fallback(self, request: 'BaseRequest',
fobj: IO[Any],
count: int) -> AbstractStreamWriter:
# Mimic the _sendfile_system() method, but without using the
# os.sendfile() system call. This should be used on systems
# that don't support os.sendfile().
# To keep memory usage low, fobj is transferred in chunks
# controlled by the constructor's chunk_size argument.
writer = await super().prepare(request)
assert writer is not None
chunk_size = self._chunk_size
loop = asyncio.get_event_loop()
chunk = await loop.run_in_executor(None, fobj.read, chunk_size)
while chunk:
await writer.write(chunk)
count = count - chunk_size
if count <= 0:
break
chunk = await loop.run_in_executor(
None, fobj.read, min(chunk_size, count)
)
await writer.drain()
return writer
if hasattr(os, "sendfile") and not NOSENDFILE: # pragma: no cover
_sendfile = _sendfile_system
else: # pragma: no cover
_sendfile = _sendfile_fallback
async def prepare(
self,
request: 'BaseRequest'
) -> Optional[AbstractStreamWriter]:
filepath = self._path
gzip = False
if 'gzip' in request.headers.get(hdrs.ACCEPT_ENCODING, ''):
gzip_path = filepath.with_name(filepath.name + '.gz')
if gzip_path.is_file():
filepath = gzip_path
gzip = True
loop = asyncio.get_event_loop()
st = await loop.run_in_executor(None, filepath.stat)
modsince = request.if_modified_since
if modsince is not None and st.st_mtime <= modsince.timestamp():
self.set_status(HTTPNotModified.status_code)
self._length_check = False
# Delete any Content-Length headers provided by the user. An HTTP 304
# response should always have an empty body
return await super().prepare(request)
unmodsince = request.if_unmodified_since
if unmodsince is not None and st.st_mtime > unmodsince.timestamp():
self.set_status(HTTPPreconditionFailed.status_code)
return await super().prepare(request)
if hdrs.CONTENT_TYPE not in self.headers:
ct, encoding = mimetypes.guess_type(str(filepath))
if not ct:
ct = 'application/octet-stream'
should_set_ct = True
else:
encoding = 'gzip' if gzip else None
should_set_ct = False
status = self._status
file_size = st.st_size
count = file_size
start = None
ifrange = request.if_range
if ifrange is None or st.st_mtime <= ifrange.timestamp():
# If-Range header check:
# if the cached date (If-Range) is not older than the last
# modification date, honour the Range header and return 206;
# otherwise ignore the Range header and return 200.
# If the condition holds but no Range header was sent,
# return 200 as well.
try:
rng = request.http_range
start = rng.start
end = rng.stop
except ValueError:
# https://tools.ietf.org/html/rfc7233:
# A server generating a 416 (Range Not Satisfiable) response to
# a byte-range request SHOULD send a Content-Range header field
# with an unsatisfied-range value.
# The complete-length in a 416 response indicates the current
# length of the selected representation.
#
# Will do the same below. Many servers ignore this and do not
# send a Content-Range header with HTTP 416
self.headers[hdrs.CONTENT_RANGE] = 'bytes */{0}'.format(
file_size)
self.set_status(HTTPRequestRangeNotSatisfiable.status_code)
return await super().prepare(request)
# If a range request has been made, convert start, end slice
# notation into file pointer offset and count
if start is not None or end is not None:
if start < 0 and end is None: # return tail of file
start += file_size
if start < 0:
# if Range:bytes=-1000 in request header but file size
# is only 200, there would be trouble without this
start = 0
count = file_size - start
else:
# RFC 7233: If the last-byte-pos value is
# absent, or if the value is greater than or equal to
# the current length of the representation data,
# the byte range is interpreted as the remainder
# of the representation (i.e., the server replaces the
# value of last-byte-pos with a value that is one less than
# the current length of the selected representation).
count = min(end if end is not None else file_size,
file_size) - start
if start >= file_size:
# HTTP 416 should be returned in this case.
#
# According to https://tools.ietf.org/html/rfc7233:
# If a valid byte-range-set includes at least one
# byte-range-spec with a first-byte-pos that is less than
# the current length of the representation, or at least one
# suffix-byte-range-spec with a non-zero suffix-length,
# then the byte-range-set is satisfiable. Otherwise, the
# byte-range-set is unsatisfiable.
self.headers[hdrs.CONTENT_RANGE] = 'bytes */{0}'.format(
file_size)
self.set_status(HTTPRequestRangeNotSatisfiable.status_code)
return await super().prepare(request)
status = HTTPPartialContent.status_code
# Even though you are sending the whole file, you should still
# return an HTTP 206 for a Range request.
self.set_status(status)
if should_set_ct:
self.content_type = ct # type: ignore
if encoding:
self.headers[hdrs.CONTENT_ENCODING] = encoding
if gzip:
self.headers[hdrs.VARY] = hdrs.ACCEPT_ENCODING
self.last_modified = st.st_mtime # type: ignore
self.content_length = count
self.headers[hdrs.ACCEPT_RANGES] = 'bytes'
real_start = cast(int, start)
if status == HTTPPartialContent.status_code:
self.headers[hdrs.CONTENT_RANGE] = 'bytes {0}-{1}/{2}'.format(
real_start, real_start + count - 1, file_size)
fobj = await loop.run_in_executor(None, filepath.open, 'rb')
if start: # note that start can be None or 0 here.
await loop.run_in_executor(None, fobj.seek, start)
try:
return await self._sendfile(request, fobj, count)
finally:
await loop.run_in_executor(None, fobj.close)
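# A hedged usage sketch (not part of the original module) returning the
# FileResponse defined above from a handler. Conditional headers, Range
# requests and precompressed ".gz" siblings are handled inside prepare();
# os.sendfile() is used automatically where the platform and transport allow.
# The file path is hypothetical.
import pathlib

from aiohttp import web

INDEX = pathlib.Path(__file__).parent / 'index.html'


async def index(request: web.Request) -> web.FileResponse:
    return web.FileResponse(INDEX, chunk_size=256 * 1024)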

@ -1,235 +0,0 @@
import datetime
import functools
import logging
import os
import re
from collections import namedtuple
from typing import Any, Callable, Dict, Iterable, List, Tuple # noqa
from .abc import AbstractAccessLogger
from .web_request import BaseRequest
from .web_response import StreamResponse
KeyMethod = namedtuple('KeyMethod', 'key method')
class AccessLogger(AbstractAccessLogger):
"""Helper object to log access.
Usage:
log = logging.getLogger("spam")
log_format = "%a %{User-Agent}i"
access_logger = AccessLogger(log, log_format)
access_logger.log(request, response, time)
Format:
%% The percent sign
%a Remote IP-address (IP-address of proxy if using reverse proxy)
%t Time when the request was started to process
%P The process ID of the child that serviced the request
%r First line of request
%s Response status code
%b Size of response in bytes, including HTTP headers
%T Time taken to serve the request, in seconds
%Tf Time taken to serve the request, in seconds with floating fraction
in .06f format
%D Time taken to serve the request, in microseconds
%{FOO}i request.headers['FOO']
%{FOO}o response.headers['FOO']
%{FOO}e os.environ['FOO']
"""
LOG_FORMAT_MAP = {
'a': 'remote_address',
't': 'request_start_time',
'P': 'process_id',
'r': 'first_request_line',
's': 'response_status',
'b': 'response_size',
'T': 'request_time',
'Tf': 'request_time_frac',
'D': 'request_time_micro',
'i': 'request_header',
'o': 'response_header',
}
LOG_FORMAT = '%a %t "%r" %s %b "%{Referer}i" "%{User-Agent}i"'
FORMAT_RE = re.compile(r'%(\{([A-Za-z0-9\-_]+)\}([ioe])|[atPrsbOD]|Tf?)')
CLEANUP_RE = re.compile(r'(%[^s])')
_FORMAT_CACHE = {} # type: Dict[str, Tuple[str, List[KeyMethod]]]
def __init__(self, logger: logging.Logger,
log_format: str=LOG_FORMAT) -> None:
"""Initialise the logger.
logger is a logger object to be used for logging.
log_format is a string with apache compatible log format description.
"""
super().__init__(logger, log_format=log_format)
_compiled_format = AccessLogger._FORMAT_CACHE.get(log_format)
if not _compiled_format:
_compiled_format = self.compile_format(log_format)
AccessLogger._FORMAT_CACHE[log_format] = _compiled_format
self._log_format, self._methods = _compiled_format
def compile_format(self, log_format: str) -> Tuple[str, List[KeyMethod]]:
"""Translate log_format into form usable by modulo formatting
All known atoms will be replaced with %s
Also methods for formatting of those atoms will be added to
_methods in appropriate order
For example we have log_format = "%a %t"
This format will be translated to "%s %s"
Also contents of _methods will be
[self._format_a, self._format_t]
These method will be called and results will be passed
to translated string format.
Each _format_* method receive 'args' which is list of arguments
given to self.log
Exceptions are _format_e, _format_i and _format_o methods which
also receive key name (by functools.partial)
"""
# list of (key, method) tuples, we don't use an OrderedDict as users
# can repeat the same key more than once
methods = list()
for atom in self.FORMAT_RE.findall(log_format):
if atom[1] == '':
format_key1 = self.LOG_FORMAT_MAP[atom[0]]
m = getattr(AccessLogger, '_format_%s' % atom[0])
key_method = KeyMethod(format_key1, m)
else:
format_key2 = (self.LOG_FORMAT_MAP[atom[2]], atom[1])
m = getattr(AccessLogger, '_format_%s' % atom[2])
key_method = KeyMethod(format_key2,
functools.partial(m, atom[1]))
methods.append(key_method)
log_format = self.FORMAT_RE.sub(r'%s', log_format)
log_format = self.CLEANUP_RE.sub(r'%\1', log_format)
return log_format, methods
@staticmethod
def _format_i(key: str,
request: BaseRequest,
response: StreamResponse,
time: float) -> str:
if request is None:
return '(no headers)'
# suboptimal, make istr(key) once
return request.headers.get(key, '-')
@staticmethod
def _format_o(key: str,
request: BaseRequest,
response: StreamResponse,
time: float) -> str:
# suboptimal, make istr(key) once
return response.headers.get(key, '-')
@staticmethod
def _format_a(request: BaseRequest,
response: StreamResponse,
time: float) -> str:
if request is None:
return '-'
ip = request.remote
return ip if ip is not None else '-'
@staticmethod
def _format_t(request: BaseRequest,
response: StreamResponse,
time: float) -> str:
now = datetime.datetime.utcnow()
start_time = now - datetime.timedelta(seconds=time)
return start_time.strftime('[%d/%b/%Y:%H:%M:%S +0000]')
@staticmethod
def _format_P(request: BaseRequest,
response: StreamResponse,
time: float) -> str:
return "<%s>" % os.getpid()
@staticmethod
def _format_r(request: BaseRequest,
response: StreamResponse,
time: float) -> str:
if request is None:
return '-'
return '%s %s HTTP/%s.%s' % (request.method, request.path_qs,
request.version.major,
request.version.minor)
@staticmethod
def _format_s(request: BaseRequest,
response: StreamResponse,
time: float) -> int:
return response.status
@staticmethod
def _format_b(request: BaseRequest,
response: StreamResponse,
time: float) -> int:
return response.body_length
@staticmethod
def _format_T(request: BaseRequest,
response: StreamResponse,
time: float) -> str:
return str(round(time))
@staticmethod
def _format_Tf(request: BaseRequest,
response: StreamResponse,
time: float) -> str:
return '%06f' % time
@staticmethod
def _format_D(request: BaseRequest,
response: StreamResponse,
time: float) -> str:
return str(round(time * 1000000))
def _format_line(self,
request: BaseRequest,
response: StreamResponse,
time: float) -> Iterable[Tuple[str,
Callable[[BaseRequest,
StreamResponse,
float],
str]]]:
return [(key, method(request, response, time))
for key, method in self._methods]
def log(self,
request: BaseRequest,
response: StreamResponse,
time: float) -> None:
try:
fmt_info = self._format_line(request, response, time)
values = list()
extra = dict()
for key, value in fmt_info:
values.append(value)
if key.__class__ is str:
extra[key] = value
else:
k1, k2 = key
dct = extra.get(k1, {}) # type: Any
dct[k2] = value
extra[k1] = dct
self.logger.info(self._log_format % tuple(values), extra=extra)
except Exception:
self.logger.exception("Error in logging")
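# A hedged sketch (not part of the original module) exercising the format
# atoms documented above; run_app() accepts an Apache-style format string and
# builds an AccessLogger from it. The route and log level are arbitrary.
import logging

from aiohttp import web


async def hello(request: web.Request) -> web.Response:
    return web.Response(text='hello')


if __name__ == '__main__':
    logging.basicConfig(level=logging.INFO)
    app = web.Application()
    app.router.add_get('/', hello)
    web.run_app(app, access_log_format='%a %t "%r" %s %b %Tf "%{User-Agent}i"')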

@ -1,120 +0,0 @@
import re
from typing import TYPE_CHECKING, Awaitable, Callable, Tuple, Type, TypeVar
from .web_exceptions import HTTPPermanentRedirect, _HTTPMove
from .web_request import Request
from .web_response import StreamResponse
from .web_urldispatcher import SystemRoute
__all__ = (
'middleware',
'normalize_path_middleware',
)
if TYPE_CHECKING: # pragma: no cover
from .web_app import Application # noqa
_Func = TypeVar('_Func')
async def _check_request_resolves(request: Request,
path: str) -> Tuple[bool, Request]:
alt_request = request.clone(rel_url=path)
match_info = await request.app.router.resolve(alt_request)
alt_request._match_info = match_info # type: ignore
if match_info.http_exception is None:
return True, alt_request
return False, request
def middleware(f: _Func) -> _Func:
f.__middleware_version__ = 1 # type: ignore
return f
_Handler = Callable[[Request], Awaitable[StreamResponse]]
_Middleware = Callable[[Request, _Handler], Awaitable[StreamResponse]]
def normalize_path_middleware(
*, append_slash: bool=True, remove_slash: bool=False,
merge_slashes: bool=True,
redirect_class: Type[_HTTPMove]=HTTPPermanentRedirect) -> _Middleware:
"""
Middleware factory which produces a middleware that normalizes
the path of a request. By normalizing it means:
- Add or remove a trailing slash to the path.
- Double slashes are replaced by one.
The middleware returns as soon as it finds a path that resolves
correctly. The order if both merge and append/remove are enabled is
1) merge slashes
2) append/remove slash
3) both merge slashes and append/remove slash.
If the path resolves with at least one of those conditions, it will
redirect to the new path.
Only one of `append_slash` and `remove_slash` can be enabled. If both
are `True` the factory will raise an assertion error
If `append_slash` is `True` the middleware will append a slash when
needed. If a resource is defined with trailing slash and the request
comes without it, it will append it automatically.
If `remove_slash` is `True`, `append_slash` must be `False`. When enabled
the middleware will remove trailing slashes and redirect if the resource
is defined
If merge_slashes is True, merge multiple consecutive slashes in the
path into one.
"""
correct_configuration = not (append_slash and remove_slash)
assert correct_configuration, "Cannot both remove and append slash"
@middleware
async def impl(request: Request, handler: _Handler) -> StreamResponse:
if isinstance(request.match_info.route, SystemRoute):
paths_to_check = []
if '?' in request.raw_path:
path, query = request.raw_path.split('?', 1)
query = '?' + query
else:
query = ''
path = request.raw_path
if merge_slashes:
paths_to_check.append(re.sub('//+', '/', path))
if append_slash and not request.path.endswith('/'):
paths_to_check.append(path + '/')
if remove_slash and request.path.endswith('/'):
paths_to_check.append(path[:-1])
if merge_slashes and append_slash:
paths_to_check.append(
re.sub('//+', '/', path + '/'))
if merge_slashes and remove_slash:
merged_slashes = re.sub('//+', '/', path)
paths_to_check.append(merged_slashes[:-1])
for path in paths_to_check:
resolves, request = await _check_request_resolves(
request, path)
if resolves:
raise redirect_class(request.raw_path + query)
return await handler(request)
return impl
def _fix_request_current_app(app: 'Application') -> _Middleware:
@middleware
async def impl(request: Request, handler: _Handler) -> StreamResponse:
with request.match_info.set_current_app(app):
return await handler(request)
return impl
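# A hedged usage sketch (not part of the original module) combining the
# @middleware decorator with the normalize_path_middleware factory above. A
# new-style middleware is a plain coroutine taking (request, handler); the
# header name and value are hypothetical.
from aiohttp import web


@web.middleware
async def add_server_header(request: web.Request, handler) -> web.StreamResponse:
    response = await handler(request)
    response.headers['X-Served-By'] = 'example'
    return response


def build_app() -> web.Application:
    return web.Application(middlewares=[
        web.normalize_path_middleware(append_slash=True, merge_slashes=True),
        add_server_header,
    ])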

@ -1,599 +0,0 @@
import asyncio
import asyncio.streams
import traceback
import warnings
from collections import deque
from contextlib import suppress
from html import escape as html_escape
from http import HTTPStatus
from logging import Logger
from typing import (
TYPE_CHECKING,
Any,
Awaitable,
Callable,
Optional,
Type,
cast,
)
import yarl
from .abc import AbstractAccessLogger, AbstractStreamWriter
from .base_protocol import BaseProtocol
from .helpers import CeilTimeout, current_task
from .http import (
HttpProcessingError,
HttpRequestParser,
HttpVersion10,
RawRequestMessage,
StreamWriter,
)
from .log import access_logger, server_logger
from .streams import EMPTY_PAYLOAD, StreamReader
from .tcp_helpers import tcp_keepalive
from .web_exceptions import HTTPException
from .web_log import AccessLogger
from .web_request import BaseRequest
from .web_response import Response, StreamResponse
__all__ = ('RequestHandler', 'RequestPayloadError', 'PayloadAccessError')
if TYPE_CHECKING: # pragma: no cover
from .web_server import Server # noqa
_RequestFactory = Callable[[RawRequestMessage,
StreamReader,
'RequestHandler',
AbstractStreamWriter,
'asyncio.Task[None]'],
BaseRequest]
_RequestHandler = Callable[[BaseRequest], Awaitable[StreamResponse]]
ERROR = RawRequestMessage(
'UNKNOWN', '/', HttpVersion10, {},
{}, True, False, False, False, yarl.URL('/'))
class RequestPayloadError(Exception):
"""Payload parsing error."""
class PayloadAccessError(Exception):
"""Payload was accessed after response was sent."""
class RequestHandler(BaseProtocol):
"""HTTP protocol implementation.
RequestHandler handles incoming HTTP requests. It reads the request
line, request headers and request payload, then calls the
handle_request() method. By default it always returns a 404 response.
RequestHandler handles errors in the incoming request, like a bad
status line, bad headers or an incomplete payload. If any error occurs,
the connection gets closed.
:param keepalive_timeout: number of seconds before closing
keep-alive connection
:type keepalive_timeout: int or None
:param bool tcp_keepalive: TCP keep-alive is on, default is on
:param bool debug: enable debug mode
:param logger: custom logger object
:type logger: aiohttp.log.server_logger
:param access_log_class: custom class for access_logger
:type access_log_class: aiohttp.abc.AbstractAccessLogger
:param access_log: custom logging object
:type access_log: aiohttp.log.server_logger
:param str access_log_format: access log format string
:param loop: Optional event loop
:param int max_line_size: Optional maximum header line size
:param int max_field_size: Optional maximum header field size
:param int max_headers: Optional maximum header size
"""
KEEPALIVE_RESCHEDULE_DELAY = 1
__slots__ = ('_request_count', '_keepalive', '_manager',
'_request_handler', '_request_factory', '_tcp_keepalive',
'_keepalive_time', '_keepalive_handle', '_keepalive_timeout',
'_lingering_time', '_messages', '_message_tail',
'_waiter', '_error_handler', '_task_handler',
'_upgrade', '_payload_parser', '_request_parser',
'_reading_paused', 'logger', 'debug', 'access_log',
'access_logger', '_close', '_force_close')
def __init__(self, manager: 'Server', *,
loop: asyncio.AbstractEventLoop,
keepalive_timeout: float=75., # NGINX default is 75 secs
tcp_keepalive: bool=True,
logger: Logger=server_logger,
access_log_class: Type[AbstractAccessLogger]=AccessLogger,
access_log: Logger=access_logger,
access_log_format: str=AccessLogger.LOG_FORMAT,
debug: bool=False,
max_line_size: int=8190,
max_headers: int=32768,
max_field_size: int=8190,
lingering_time: float=10.0):
super().__init__(loop)
self._request_count = 0
self._keepalive = False
self._manager = manager # type: Optional[Server]
self._request_handler = manager.request_handler # type: Optional[_RequestHandler] # noqa
self._request_factory = manager.request_factory # type: Optional[_RequestFactory] # noqa
self._tcp_keepalive = tcp_keepalive
# placeholder to be replaced on keepalive timeout setup
self._keepalive_time = 0.0
self._keepalive_handle = None # type: Optional[asyncio.Handle]
self._keepalive_timeout = keepalive_timeout
self._lingering_time = float(lingering_time)
self._messages = deque() # type: Any # Python 3.5 has no typing.Deque
self._message_tail = b''
self._waiter = None # type: Optional[asyncio.Future[None]]
self._error_handler = None # type: Optional[asyncio.Task[None]]
self._task_handler = None # type: Optional[asyncio.Task[None]]
self._upgrade = False
self._payload_parser = None # type: Any
self._request_parser = HttpRequestParser(
self, loop,
max_line_size=max_line_size,
max_field_size=max_field_size,
max_headers=max_headers,
payload_exception=RequestPayloadError) # type: Optional[HttpRequestParser] # noqa
self.logger = logger
self.debug = debug
self.access_log = access_log
if access_log:
self.access_logger = access_log_class(
access_log, access_log_format) # type: Optional[AbstractAccessLogger] # noqa
else:
self.access_logger = None
self._close = False
self._force_close = False
def __repr__(self) -> str:
return "<{} {}>".format(
self.__class__.__name__,
'connected' if self.transport is not None else 'disconnected')
@property
def keepalive_timeout(self) -> float:
return self._keepalive_timeout
async def shutdown(self, timeout: Optional[float]=15.0) -> None:
"""Worker process is about to exit, we need cleanup everything and
stop accepting requests. It is especially important for keep-alive
connections."""
self._force_close = True
if self._keepalive_handle is not None:
self._keepalive_handle.cancel()
if self._waiter:
self._waiter.cancel()
# wait for handlers
with suppress(asyncio.CancelledError, asyncio.TimeoutError):
with CeilTimeout(timeout, loop=self._loop):
if (self._error_handler is not None and
not self._error_handler.done()):
await self._error_handler
if (self._task_handler is not None and
not self._task_handler.done()):
await self._task_handler
# force-close non-idle handler
if self._task_handler is not None:
self._task_handler.cancel()
if self.transport is not None:
self.transport.close()
self.transport = None
def connection_made(self, transport: asyncio.BaseTransport) -> None:
super().connection_made(transport)
real_transport = cast(asyncio.Transport, transport)
if self._tcp_keepalive:
tcp_keepalive(real_transport)
self._task_handler = self._loop.create_task(self.start())
assert self._manager is not None
self._manager.connection_made(self, real_transport)
def connection_lost(self, exc: Optional[BaseException]) -> None:
if self._manager is None:
return
self._manager.connection_lost(self, exc)
super().connection_lost(exc)
self._manager = None
self._force_close = True
self._request_factory = None
self._request_handler = None
self._request_parser = None
if self._keepalive_handle is not None:
self._keepalive_handle.cancel()
if self._task_handler is not None:
self._task_handler.cancel()
if self._error_handler is not None:
self._error_handler.cancel()
self._task_handler = None
if self._payload_parser is not None:
self._payload_parser.feed_eof()
self._payload_parser = None
def set_parser(self, parser: Any) -> None:
# Actual type is WebReader
assert self._payload_parser is None
self._payload_parser = parser
if self._message_tail:
self._payload_parser.feed_data(self._message_tail)
self._message_tail = b''
def eof_received(self) -> None:
pass
def data_received(self, data: bytes) -> None:
if self._force_close or self._close:
return
# parse http messages
if self._payload_parser is None and not self._upgrade:
assert self._request_parser is not None
try:
messages, upgraded, tail = self._request_parser.feed_data(data)
except HttpProcessingError as exc:
# something happened during parsing
self._error_handler = self._loop.create_task(
self.handle_parse_error(
StreamWriter(self, self._loop),
400, exc, exc.message))
self.close()
except Exception as exc:
# 500: internal error
self._error_handler = self._loop.create_task(
self.handle_parse_error(
StreamWriter(self, self._loop),
500, exc))
self.close()
else:
if messages:
# sometimes the parser returns no messages
for (msg, payload) in messages:
self._request_count += 1
self._messages.append((msg, payload))
waiter = self._waiter
if waiter is not None:
if not waiter.done():
# don't set result twice
waiter.set_result(None)
self._upgrade = upgraded
if upgraded and tail:
self._message_tail = tail
# no parser, just store
elif self._payload_parser is None and self._upgrade and data:
self._message_tail += data
# feed payload
elif data:
eof, tail = self._payload_parser.feed_data(data)
if eof:
self.close()
def keep_alive(self, val: bool) -> None:
"""Set keep-alive connection mode.
:param bool val: new state.
"""
self._keepalive = val
if self._keepalive_handle:
self._keepalive_handle.cancel()
self._keepalive_handle = None
def close(self) -> None:
"""Stop accepting new pipelinig messages and close
connection when handlers done processing messages"""
self._close = True
if self._waiter:
self._waiter.cancel()
def force_close(self) -> None:
"""Force close connection"""
self._force_close = True
if self._waiter:
self._waiter.cancel()
if self.transport is not None:
self.transport.close()
self.transport = None
def log_access(self,
request: BaseRequest,
response: StreamResponse,
time: float) -> None:
if self.access_logger is not None:
self.access_logger.log(request, response, time)
def log_debug(self, *args: Any, **kw: Any) -> None:
if self.debug:
self.logger.debug(*args, **kw)
def log_exception(self, *args: Any, **kw: Any) -> None:
self.logger.exception(*args, **kw)
def _process_keepalive(self) -> None:
if self._force_close or not self._keepalive:
return
next_time = self._keepalive_time + self._keepalive_timeout
# handler in idle state
if self._waiter:
if self._loop.time() > next_time:
self.force_close()
return
# not all request handlers are done,
# reschedule ourselves for one second later
self._keepalive_handle = self._loop.call_later(
self.KEEPALIVE_RESCHEDULE_DELAY, self._process_keepalive)
async def start(self) -> None:
"""Process incoming request.
It reads request line, request headers and request payload, then
calls handle_request() method. Subclass has to override
handle_request(). start() handles various exceptions in request
or response handling. Connection is being closed always unless
keep_alive(True) specified.
"""
loop = self._loop
handler = self._task_handler
assert handler is not None
manager = self._manager
assert manager is not None
keepalive_timeout = self._keepalive_timeout
resp = None
assert self._request_factory is not None
assert self._request_handler is not None
while not self._force_close:
if not self._messages:
try:
# wait for next request
self._waiter = loop.create_future()
await self._waiter
except asyncio.CancelledError:
break
finally:
self._waiter = None
message, payload = self._messages.popleft()
if self.access_log:
now = loop.time()
manager.requests_count += 1
writer = StreamWriter(self, loop)
request = self._request_factory(
message, payload, self, writer, handler)
try:
# a new task is used to copy context vars (#3406)
task = self._loop.create_task(
self._request_handler(request))
try:
resp = await task
except HTTPException as exc:
resp = exc
except (asyncio.CancelledError, ConnectionError):
self.log_debug('Ignored premature client disconnection')
break
except asyncio.TimeoutError as exc:
self.log_debug('Request handler timed out.', exc_info=exc)
resp = self.handle_error(request, 504)
except Exception as exc:
resp = self.handle_error(request, 500, exc)
else:
# Deprecation warning (See #2415)
if getattr(resp, '__http_exception__', False):
warnings.warn(
"returning HTTPException object is deprecated "
"(#2415) and will be removed, "
"please raise the exception instead",
DeprecationWarning)
# Drop the processed task from asyncio.Task.all_tasks() early
del task
if self.debug:
if not isinstance(resp, StreamResponse):
if resp is None:
raise RuntimeError("Missing return "
"statement on request handler")
else:
raise RuntimeError("Web-handler should return "
"a response instance, "
"got {!r}".format(resp))
try:
prepare_meth = resp.prepare
except AttributeError:
if resp is None:
raise RuntimeError("Missing return "
"statement on request handler")
else:
raise RuntimeError("Web-handler should return "
"a response instance, "
"got {!r}".format(resp))
try:
await prepare_meth(request)
await resp.write_eof()
except ConnectionError:
self.log_debug('Ignored premature client disconnection 2')
break
# notify server about keep-alive
self._keepalive = bool(resp.keep_alive)
# log access
if self.access_log:
self.log_access(request, resp, loop.time() - now)
# check payload
if not payload.is_eof():
lingering_time = self._lingering_time
if not self._force_close and lingering_time:
self.log_debug(
'Start lingering close timer for %s sec.',
lingering_time)
now = loop.time()
end_t = now + lingering_time
with suppress(
asyncio.TimeoutError, asyncio.CancelledError):
while not payload.is_eof() and now < end_t:
with CeilTimeout(end_t - now, loop=loop):
# read and ignore
await payload.readany()
now = loop.time()
# if the payload is still incomplete
if not payload.is_eof() and not self._force_close:
self.log_debug('Uncompleted request.')
self.close()
payload.set_exception(PayloadAccessError())
except asyncio.CancelledError:
self.log_debug('Ignored premature client disconnection ')
break
except RuntimeError as exc:
if self.debug:
self.log_exception(
'Unhandled runtime exception', exc_info=exc)
self.force_close()
except Exception as exc:
self.log_exception('Unhandled exception', exc_info=exc)
self.force_close()
finally:
if self.transport is None and resp is not None:
self.log_debug('Ignored premature client disconnection.')
elif not self._force_close:
if self._keepalive and not self._close:
# start keep-alive timer
if keepalive_timeout is not None:
now = self._loop.time()
self._keepalive_time = now
if self._keepalive_handle is None:
self._keepalive_handle = loop.call_at(
now + keepalive_timeout,
self._process_keepalive)
else:
break
# remove handler, close transport if no handlers left
if not self._force_close:
self._task_handler = None
if self.transport is not None and self._error_handler is None:
self.transport.close()
def handle_error(self,
request: BaseRequest,
status: int=500,
exc: Optional[BaseException]=None,
message: Optional[str]=None) -> StreamResponse:
"""Handle errors.
Returns an HTTP response with a specific status code. Logs additional
information. It always closes the current connection."""
self.log_exception("Error handling request", exc_info=exc)
ct = 'text/plain'
if status == HTTPStatus.INTERNAL_SERVER_ERROR:
title = '{0.value} {0.phrase}'.format(
HTTPStatus.INTERNAL_SERVER_ERROR
)
msg = HTTPStatus.INTERNAL_SERVER_ERROR.description
tb = None
if self.debug:
with suppress(Exception):
tb = traceback.format_exc()
if 'text/html' in request.headers.get('Accept', ''):
if tb:
tb = html_escape(tb)
msg = '<h2>Traceback:</h2>\n<pre>{}</pre>'.format(tb)
message = (
"<html><head>"
"<title>{title}</title>"
"</head><body>\n<h1>{title}</h1>"
"\n{msg}\n</body></html>\n"
).format(title=title, msg=msg)
ct = 'text/html'
else:
if tb:
msg = tb
message = title + '\n\n' + msg
resp = Response(status=status, text=message, content_type=ct)
resp.force_close()
# some data already got sent, connection is broken
if request.writer.output_size > 0 or self.transport is None:
self.force_close()
return resp
async def handle_parse_error(self,
writer: AbstractStreamWriter,
status: int,
exc: Optional[BaseException]=None,
message: Optional[str]=None) -> None:
request = BaseRequest( # type: ignore
ERROR,
EMPTY_PAYLOAD,
self, writer,
current_task(),
self._loop)
resp = self.handle_error(request, status, exc, message)
await resp.prepare(request)
await resp.write_eof()
if self.transport is not None:
self.transport.close()
self._error_handler = None
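# A hedged sketch (not part of the original module) of the low-level server
# that drives the RequestHandler protocol above: web.Server wraps a single
# handler coroutine, and the runner creates one RequestHandler per incoming
# connection. Host, port and the serving duration are arbitrary.
import asyncio

from aiohttp import web


async def handler(request: web.BaseRequest) -> web.Response:
    return web.Response(text='OK')


async def main() -> None:
    server = web.Server(handler)
    runner = web.ServerRunner(server)
    await runner.setup()
    site = web.TCPSite(runner, 'localhost', 8080)
    await site.start()
    await asyncio.sleep(3600)  # keep serving for an hour in this sketch
    await runner.cleanup()


if __name__ == '__main__':
    asyncio.get_event_loop().run_until_complete(main())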

@ -0,0 +1,895 @@
import asyncio
import binascii
import cgi
import collections
import datetime
import enum
import http.cookies
import io
import json
import math
import time
import warnings
from email.utils import parsedate
from types import MappingProxyType
from urllib.parse import parse_qsl, unquote, urlsplit
from multidict import CIMultiDict, CIMultiDictProxy, MultiDict, MultiDictProxy
from . import hdrs, multipart
from .helpers import reify, sentinel
from .protocol import Response as ResponseImpl
from .protocol import HttpVersion10, HttpVersion11
from .streams import EOF_MARKER
__all__ = (
'ContentCoding', 'Request', 'StreamResponse', 'Response',
'json_response'
)
class HeadersMixin:
_content_type = None
_content_dict = None
_stored_content_type = sentinel
def _parse_content_type(self, raw):
self._stored_content_type = raw
if raw is None:
# default value according to RFC 2616
self._content_type = 'application/octet-stream'
self._content_dict = {}
else:
self._content_type, self._content_dict = cgi.parse_header(raw)
@property
def content_type(self, _CONTENT_TYPE=hdrs.CONTENT_TYPE):
"""The value of content part for Content-Type HTTP header."""
raw = self.headers.get(_CONTENT_TYPE)
if self._stored_content_type != raw:
self._parse_content_type(raw)
return self._content_type
@property
def charset(self, _CONTENT_TYPE=hdrs.CONTENT_TYPE):
"""The value of charset part for Content-Type HTTP header."""
raw = self.headers.get(_CONTENT_TYPE)
if self._stored_content_type != raw:
self._parse_content_type(raw)
return self._content_dict.get('charset')
@property
def content_length(self, _CONTENT_LENGTH=hdrs.CONTENT_LENGTH):
"""The value of Content-Length HTTP header."""
length = self.headers.get(_CONTENT_LENGTH)
if length is None:
return None
else:
return int(length)
FileField = collections.namedtuple('Field', 'name filename file content_type')
class ContentCoding(enum.Enum):
# The content codings that we have support for.
#
# Additional registered codings are listed at:
# https://www.iana.org/assignments/http-parameters/http-parameters.xhtml#content-coding
deflate = 'deflate'
gzip = 'gzip'
identity = 'identity'
############################################################
# HTTP Request
############################################################
class Request(dict, HeadersMixin):
POST_METHODS = {hdrs.METH_PATCH, hdrs.METH_POST, hdrs.METH_PUT,
hdrs.METH_TRACE, hdrs.METH_DELETE}
def __init__(self, app, message, payload, transport, reader, writer, *,
secure_proxy_ssl_header=None):
self._app = app
self._message = message
self._transport = transport
self._reader = reader
self._writer = writer
self._post = None
self._post_files_cache = None
# matchdict, route_name, handler
# or information about traversal lookup
self._match_info = None # initialized after route resolving
self._payload = payload
self._read_bytes = None
self._has_body = not payload.at_eof()
self._secure_proxy_ssl_header = secure_proxy_ssl_header
@reify
def scheme(self):
"""A string representing the scheme of the request.
'http' or 'https'.
"""
if self._transport.get_extra_info('sslcontext'):
return 'https'
secure_proxy_ssl_header = self._secure_proxy_ssl_header
if secure_proxy_ssl_header is not None:
header, value = secure_proxy_ssl_header
if self.headers.get(header) == value:
return 'https'
return 'http'
@reify
def method(self):
"""Read only property for getting HTTP method.
The value is upper-cased str like 'GET', 'POST', 'PUT' etc.
"""
return self._message.method
@reify
def version(self):
"""Read only property for getting HTTP version of request.
Returns aiohttp.protocol.HttpVersion instance.
"""
return self._message.version
@reify
def host(self):
"""Read only property for getting *HOST* header of request.
Returns str or None if HTTP request has no HOST header.
"""
return self._message.headers.get(hdrs.HOST)
@reify
def path_qs(self):
"""The URL including PATH_INFO and the query string.
E.g., /app/blog?id=10
"""
return self._message.path
@reify
def _splitted_path(self):
url = '{}://{}{}'.format(self.scheme, self.host, self.path_qs)
return urlsplit(url)
@reify
def raw_path(self):
""" The URL including raw *PATH INFO* without the host or scheme.
Warning: the path is unquoted and may contain non-valid URL characters,
E.g., ``/my%2Fpath%7Cwith%21some%25strange%24characters``
"""
return self._splitted_path.path
@reify
def path(self):
"""The URL including *PATH INFO* without the host or scheme.
E.g., ``/app/blog``
"""
return unquote(self.raw_path)
@reify
def query_string(self):
"""The query string in the URL.
E.g., id=10
"""
return self._splitted_path.query
@reify
def GET(self):
"""A multidict with all the variables in the query string.
Lazy property.
"""
return MultiDictProxy(MultiDict(parse_qsl(self.query_string,
keep_blank_values=True)))
@reify
def POST(self):
"""A multidict with all the variables in the POST parameters.
The post() method has to be called before using this attribute.
"""
if self._post is None:
raise RuntimeError("POST is not available before post()")
return self._post
@reify
def headers(self):
"""A case-insensitive multidict proxy with all headers."""
return CIMultiDictProxy(self._message.headers)
@reify
def raw_headers(self):
"""A sequence of pars for all headers."""
return tuple(self._message.raw_headers)
@reify
def if_modified_since(self, _IF_MODIFIED_SINCE=hdrs.IF_MODIFIED_SINCE):
"""The value of If-Modified-Since HTTP header, or None.
This header is represented as a `datetime` object.
"""
httpdate = self.headers.get(_IF_MODIFIED_SINCE)
if httpdate is not None:
timetuple = parsedate(httpdate)
if timetuple is not None:
return datetime.datetime(*timetuple[:6],
tzinfo=datetime.timezone.utc)
return None
@reify
def keep_alive(self):
"""Is keepalive enabled by client?"""
if self.version < HttpVersion10:
return False
else:
return not self._message.should_close
@property
def match_info(self):
"""Result of route resolving."""
return self._match_info
@property
def app(self):
"""Application instance."""
return self._app
@property
def transport(self):
"""Transport used for request processing."""
return self._transport
@reify
def cookies(self):
"""Return request cookies.
A read-only dictionary-like object.
"""
raw = self.headers.get(hdrs.COOKIE, '')
parsed = http.cookies.SimpleCookie(raw)
return MappingProxyType(
{key: val.value for key, val in parsed.items()})
@property
def content(self):
"""Return raw payload stream."""
return self._payload
@property
def has_body(self):
"""Return True if request has HTTP BODY, False otherwise."""
return self._has_body
@asyncio.coroutine
def release(self):
"""Release request.
Eat unread part of HTTP BODY if present.
"""
chunk = yield from self._payload.readany()
while chunk is not EOF_MARKER or chunk:
chunk = yield from self._payload.readany()
@asyncio.coroutine
def read(self):
"""Read request body if present.
Returns bytes object with full request content.
"""
if self._read_bytes is None:
body = bytearray()
while True:
chunk = yield from self._payload.readany()
body.extend(chunk)
if chunk is EOF_MARKER:
break
self._read_bytes = bytes(body)
return self._read_bytes
@asyncio.coroutine
def text(self):
"""Return BODY as text using encoding from .charset."""
bytes_body = yield from self.read()
encoding = self.charset or 'utf-8'
return bytes_body.decode(encoding)
@asyncio.coroutine
def json(self, *, loads=json.loads, loader=None):
"""Return BODY as JSON."""
if loader is not None:
warnings.warn(
"Using loader argument is deprecated, use loads instead",
DeprecationWarning)
loads = loader
body = yield from self.text()
return loads(body)
@asyncio.coroutine
def multipart(self, *, reader=multipart.MultipartReader):
"""Return async iterator to process BODY as multipart."""
return reader(self.headers, self.content)
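# --- Editor's usage sketch (not part of the original source). ---
# The coroutines above expose the body at increasing levels of
# abstraction: .read() -> bytes, .text() -> str, .json() -> decoded
# object. A minimal, hypothetical handler in the old yield-from style:
#
#     @asyncio.coroutine
#     def echo(request):
#         data = yield from request.json()  # decodes via .text()/.read()
#         return Response(text=repr(data))
#
# --- end of sketch ---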
@asyncio.coroutine
def post(self):
"""Return POST parameters."""
if self._post is not None:
return self._post
if self.method not in self.POST_METHODS:
self._post = MultiDictProxy(MultiDict())
return self._post
content_type = self.content_type
if (content_type not in ('',
'application/x-www-form-urlencoded',
'multipart/form-data')):
self._post = MultiDictProxy(MultiDict())
return self._post
if self.content_type.startswith('multipart/'):
warnings.warn('To process multipart requests use .multipart'
' coroutine instead.', DeprecationWarning)
body = yield from self.read()
content_charset = self.charset or 'utf-8'
environ = {'REQUEST_METHOD': self.method,
'CONTENT_LENGTH': str(len(body)),
'QUERY_STRING': '',
'CONTENT_TYPE': self.headers.get(hdrs.CONTENT_TYPE)}
fs = cgi.FieldStorage(fp=io.BytesIO(body),
environ=environ,
keep_blank_values=True,
encoding=content_charset)
supported_transfer_encoding = {
'base64': binascii.a2b_base64,
'quoted-printable': binascii.a2b_qp
}
out = MultiDict()
_count = 1
for field in fs.list or ():
transfer_encoding = field.headers.get(
hdrs.CONTENT_TRANSFER_ENCODING, None)
if field.filename:
ff = FileField(field.name,
field.filename,
field.file, # N.B. file closed error
field.type)
if self._post_files_cache is None:
self._post_files_cache = {}
self._post_files_cache[field.name+str(_count)] = field
_count += 1
out.add(field.name, ff)
else:
value = field.value
if transfer_encoding in supported_transfer_encoding:
# binascii accepts bytes
value = value.encode('utf-8')
value = supported_transfer_encoding[
transfer_encoding](value)
out.add(field.name, value)
self._post = MultiDictProxy(out)
return self._post
def copy(self):
raise NotImplementedError
def __repr__(self):
ascii_encodable_path = self.path.encode('ascii', 'backslashreplace') \
.decode('ascii')
return "<{} {} {} >".format(self.__class__.__name__,
self.method, ascii_encodable_path)
############################################################
# HTTP Response classes
############################################################
class StreamResponse(HeadersMixin):
def __init__(self, *, status=200, reason=None, headers=None):
self._body = None
self._keep_alive = None
self._chunked = False
self._chunk_size = None
self._compression = False
self._compression_force = False
self._headers = CIMultiDict()
self._cookies = http.cookies.SimpleCookie()
self.set_status(status, reason)
self._req = None
self._resp_impl = None
self._eof_sent = False
self._tcp_nodelay = True
self._tcp_cork = False
if headers is not None:
self._headers.extend(headers)
self._parse_content_type(self._headers.get(hdrs.CONTENT_TYPE))
self._generate_content_type_header()
def _copy_cookies(self):
for cookie in self._cookies.values():
value = cookie.output(header='')[1:]
self.headers.add(hdrs.SET_COOKIE, value)
@property
def prepared(self):
return self._resp_impl is not None
@property
def started(self):
warnings.warn('use Response.prepared instead', DeprecationWarning)
return self.prepared
@property
def status(self):
return self._status
@property
def chunked(self):
return self._chunked
@property
def compression(self):
return self._compression
@property
def reason(self):
return self._reason
def set_status(self, status, reason=None):
self._status = int(status)
if reason is None:
reason = ResponseImpl.calc_reason(status)
self._reason = reason
@property
def keep_alive(self):
return self._keep_alive
def force_close(self):
self._keep_alive = False
def enable_chunked_encoding(self, chunk_size=None):
"""Enables automatic chunked transfer encoding."""
self._chunked = True
self._chunk_size = chunk_size
def enable_compression(self, force=None):
"""Enables response compression encoding."""
# Backwards compatibility for when force was a bool <0.17.
if type(force) == bool:
force = ContentCoding.deflate if force else ContentCoding.identity
elif force is not None:
assert isinstance(force, ContentCoding), ("force should be one of "
"None, bool or "
"ContentCoding")
self._compression = True
self._compression_force = force
@property
def headers(self):
return self._headers
@property
def cookies(self):
return self._cookies
def set_cookie(self, name, value, *, expires=None,
domain=None, max_age=None, path='/',
secure=None, httponly=None, version=None):
"""Set or update response cookie.
Sets a new cookie or updates an existing one with a new value.
Only those params which are not None are updated.
"""
old = self._cookies.get(name)
if old is not None and old.coded_value == '':
# deleted cookie
self._cookies.pop(name, None)
self._cookies[name] = value
c = self._cookies[name]
if expires is not None:
c['expires'] = expires
elif c.get('expires') == 'Thu, 01 Jan 1970 00:00:00 GMT':
del c['expires']
if domain is not None:
c['domain'] = domain
if max_age is not None:
c['max-age'] = max_age
elif 'max-age' in c:
del c['max-age']
c['path'] = path
if secure is not None:
c['secure'] = secure
if httponly is not None:
c['httponly'] = httponly
if version is not None:
c['version'] = version
def del_cookie(self, name, *, domain=None, path='/'):
"""Delete cookie.
Creates new empty expired cookie.
"""
# TODO: do we need domain/path here?
self._cookies.pop(name, None)
self.set_cookie(name, '', max_age=0,
expires="Thu, 01 Jan 1970 00:00:00 GMT",
domain=domain, path=path)
@property
def content_length(self):
# Just a placeholder for adding setter
return super().content_length
@content_length.setter
def content_length(self, value):
if value is not None:
value = int(value)
# TODO: raise error if chunked enabled
self.headers[hdrs.CONTENT_LENGTH] = str(value)
else:
self.headers.pop(hdrs.CONTENT_LENGTH, None)
@property
def content_type(self):
# Just a placeholder for adding setter
return super().content_type
@content_type.setter
def content_type(self, value):
self.content_type # read header values if needed
self._content_type = str(value)
self._generate_content_type_header()
@property
def charset(self):
# Just a placeholder for adding setter
return super().charset
@charset.setter
def charset(self, value):
ctype = self.content_type # read header values if needed
if ctype == 'application/octet-stream':
raise RuntimeError("Setting charset for application/octet-stream "
"doesn't make sense, setup content_type first")
if value is None:
self._content_dict.pop('charset', None)
else:
self._content_dict['charset'] = str(value).lower()
self._generate_content_type_header()
@property
def last_modified(self, _LAST_MODIFIED=hdrs.LAST_MODIFIED):
"""The value of Last-Modified HTTP header, or None.
This header is represented as a `datetime` object.
"""
httpdate = self.headers.get(_LAST_MODIFIED)
if httpdate is not None:
timetuple = parsedate(httpdate)
if timetuple is not None:
return datetime.datetime(*timetuple[:6],
tzinfo=datetime.timezone.utc)
return None
@last_modified.setter
def last_modified(self, value):
if value is None:
self.headers.pop(hdrs.LAST_MODIFIED, None)
elif isinstance(value, (int, float)):
self.headers[hdrs.LAST_MODIFIED] = time.strftime(
"%a, %d %b %Y %H:%M:%S GMT", time.gmtime(math.ceil(value)))
elif isinstance(value, datetime.datetime):
self.headers[hdrs.LAST_MODIFIED] = time.strftime(
"%a, %d %b %Y %H:%M:%S GMT", value.utctimetuple())
elif isinstance(value, str):
self.headers[hdrs.LAST_MODIFIED] = value
@property
def tcp_nodelay(self):
return self._tcp_nodelay
def set_tcp_nodelay(self, value):
value = bool(value)
self._tcp_nodelay = value
if value:
self._tcp_cork = False
if self._resp_impl is None:
return
if value:
self._resp_impl.transport.set_tcp_cork(False)
self._resp_impl.transport.set_tcp_nodelay(value)
@property
def tcp_cork(self):
return self._tcp_cork
def set_tcp_cork(self, value):
value = bool(value)
self._tcp_cork = value
if value:
self._tcp_nodelay = False
if self._resp_impl is None:
return
if value:
self._resp_impl.transport.set_tcp_nodelay(False)
self._resp_impl.transport.set_tcp_cork(value)
def _generate_content_type_header(self, CONTENT_TYPE=hdrs.CONTENT_TYPE):
params = '; '.join("%s=%s" % i for i in self._content_dict.items())
if params:
ctype = self._content_type + '; ' + params
else:
ctype = self._content_type
self.headers[CONTENT_TYPE] = ctype
def _start_pre_check(self, request):
if self._resp_impl is not None:
if self._req is not request:
raise RuntimeError(
"Response has been started with different request.")
else:
return self._resp_impl
else:
return None
def _do_start_compression(self, coding):
if coding != ContentCoding.identity:
self.headers[hdrs.CONTENT_ENCODING] = coding.value
self._resp_impl.add_compression_filter(coding.value)
self.content_length = None
def _start_compression(self, request):
if self._compression_force:
self._do_start_compression(self._compression_force)
else:
accept_encoding = request.headers.get(
hdrs.ACCEPT_ENCODING, '').lower()
for coding in ContentCoding:
if coding.value in accept_encoding:
self._do_start_compression(coding)
return
def start(self, request):
warnings.warn('use .prepare(request) instead', DeprecationWarning)
resp_impl = self._start_pre_check(request)
if resp_impl is not None:
return resp_impl
return self._start(request)
@asyncio.coroutine
def prepare(self, request):
resp_impl = self._start_pre_check(request)
if resp_impl is not None:
return resp_impl
yield from request.app.on_response_prepare.send(request, self)
return self._start(request)
def _start(self, request):
self._req = request
keep_alive = self._keep_alive
if keep_alive is None:
keep_alive = request.keep_alive
self._keep_alive = keep_alive
resp_impl = self._resp_impl = ResponseImpl(
request._writer,
self._status,
request.version,
not keep_alive,
self._reason)
self._copy_cookies()
if self._compression:
self._start_compression(request)
if self._chunked:
if request.version != HttpVersion11:
raise RuntimeError("Using chunked encoding is forbidden "
"for HTTP/{0.major}.{0.minor}".format(
request.version))
resp_impl.enable_chunked_encoding()
if self._chunk_size:
resp_impl.add_chunking_filter(self._chunk_size)
headers = self.headers.items()
for key, val in headers:
resp_impl.add_header(key, val)
resp_impl.transport.set_tcp_nodelay(self._tcp_nodelay)
resp_impl.transport.set_tcp_cork(self._tcp_cork)
self._send_headers(resp_impl)
return resp_impl
def _send_headers(self, resp_impl):
# Dirty hack required for
# https://github.com/KeepSafe/aiohttp/issues/1093
# File sender may override it
resp_impl.send_headers()
def write(self, data):
assert isinstance(data, (bytes, bytearray, memoryview)), \
"data argument must be byte-ish (%r)" % type(data)
if self._eof_sent:
raise RuntimeError("Cannot call write() after write_eof()")
if self._resp_impl is None:
raise RuntimeError("Cannot call write() before start()")
if data:
return self._resp_impl.write(data)
else:
return ()
@asyncio.coroutine
def drain(self):
if self._resp_impl is None:
raise RuntimeError("Response has not been started")
yield from self._resp_impl.transport.drain()
@asyncio.coroutine
def write_eof(self):
if self._eof_sent:
return
if self._resp_impl is None:
raise RuntimeError("Response has not been started")
yield from self._resp_impl.write_eof()
self._eof_sent = True
def __repr__(self):
if self.started:
info = "{} {} ".format(self._req.method, self._req.path)
else:
info = "not started"
return "<{} {} {}>".format(self.__class__.__name__,
self.reason, info)
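# --- Editor's usage sketch (not part of the original source). A minimal
# streaming handler built on the class above: prepare() sends the status
# line and headers, write() pushes body chunks (a plain method in this
# version), drain() applies flow control, and write_eof() finishes the
# message. The handler name is hypothetical.
@asyncio.coroutine
def stream_numbers(request):
    resp = StreamResponse()
    resp.content_type = 'text/plain'
    yield from resp.prepare(request)
    for i in range(3):
        resp.write(b'%d\n' % i)  # bytes %-formatting needs Python 3.5+
        yield from resp.drain()
    yield from resp.write_eof()
    return resp
# --- end of sketch ---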
class Response(StreamResponse):
def __init__(self, *, body=None, status=200,
reason=None, text=None, headers=None, content_type=None,
charset=None):
if body is not None and text is not None:
raise ValueError("body and text are not allowed together")
if headers is None:
headers = CIMultiDict()
elif not isinstance(headers, (CIMultiDict, CIMultiDictProxy)):
headers = CIMultiDict(headers)
if content_type is not None and ";" in content_type:
raise ValueError("charset must not be in content_type "
"argument")
if text is not None:
if hdrs.CONTENT_TYPE in headers:
if content_type or charset:
raise ValueError("passing both Content-Type header and "
"content_type or charset params "
"is forbidden")
else:
# fast path for filling headers
if not isinstance(text, str):
raise TypeError("text argument must be str (%r)" %
type(text))
if content_type is None:
content_type = 'text/plain'
if charset is None:
charset = 'utf-8'
headers[hdrs.CONTENT_TYPE] = (
content_type + '; charset=' + charset)
body = text.encode(charset)
text = None
else:
if hdrs.CONTENT_TYPE in headers:
if content_type is not None or charset is not None:
raise ValueError("passing both Content-Type header and "
"content_type or charset params "
"is forbidden")
else:
if content_type is not None:
if charset is not None:
content_type += '; charset=' + charset
headers[hdrs.CONTENT_TYPE] = content_type
super().__init__(status=status, reason=reason, headers=headers)
self.set_tcp_cork(True)
if text is not None:
self.text = text
else:
self.body = body
@property
def body(self):
return self._body
@body.setter
def body(self, body):
if body is not None and not isinstance(body, bytes):
raise TypeError("body argument must be bytes (%r)" % type(body))
self._body = body
if body is not None:
self.content_length = len(body)
else:
self.content_length = 0
@property
def text(self):
if self._body is None:
return None
return self._body.decode(self.charset or 'utf-8')
@text.setter
def text(self, text):
if text is not None and not isinstance(text, str):
raise TypeError("text argument must be str (%r)" % type(text))
if self.content_type == 'application/octet-stream':
self.content_type = 'text/plain'
if self.charset is None:
self.charset = 'utf-8'
self.body = text.encode(self.charset)
@asyncio.coroutine
def write_eof(self):
try:
body = self._body
if (body is not None and
self._req.method != hdrs.METH_HEAD and
self._status not in [204, 304]):
self.write(body)
finally:
self.set_tcp_nodelay(True)
yield from super().write_eof()
def json_response(data=sentinel, *, text=None, body=None, status=200,
reason=None, headers=None, content_type='application/json',
dumps=json.dumps):
if data is not sentinel:
if text or body:
raise ValueError(
"only one of data, text, or body should be specified"
)
else:
text = dumps(data)
return Response(text=text, body=body, status=status, reason=reason,
headers=headers, content_type=content_type)
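# --- Editor's usage sketch (not part of the original source). The helper
# above serializes `data` with json.dumps and returns a ready Response
# with an application/json content type; the handler name is hypothetical.
@asyncio.coroutine
def healthcheck(request):
    return json_response({'status': 'ok'})
# --- end of sketch ---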
@ -1,754 +0,0 @@
import asyncio
import datetime
import io
import re
import socket
import string
import tempfile
import types
import warnings
from email.utils import parsedate
from http.cookies import SimpleCookie
from types import MappingProxyType
from typing import ( # noqa
TYPE_CHECKING,
Any,
Dict,
Iterator,
Mapping,
MutableMapping,
Optional,
Tuple,
Union,
cast,
)
from urllib.parse import parse_qsl
import attr
from multidict import CIMultiDict, CIMultiDictProxy, MultiDict, MultiDictProxy
from yarl import URL
from . import hdrs
from .abc import AbstractStreamWriter
from .helpers import DEBUG, ChainMapProxy, HeadersMixin, reify, sentinel
from .http_parser import RawRequestMessage
from .multipart import BodyPartReader, MultipartReader
from .streams import EmptyStreamReader, StreamReader
from .typedefs import (
DEFAULT_JSON_DECODER,
JSONDecoder,
LooseHeaders,
RawHeaders,
StrOrURL,
)
from .web_exceptions import HTTPRequestEntityTooLarge
from .web_response import StreamResponse
__all__ = ('BaseRequest', 'FileField', 'Request')
if TYPE_CHECKING: # pragma: no cover
from .web_app import Application # noqa
from .web_urldispatcher import UrlMappingMatchInfo # noqa
from .web_protocol import RequestHandler # noqa
@attr.s(frozen=True, slots=True)
class FileField:
name = attr.ib(type=str)
filename = attr.ib(type=str)
file = attr.ib(type=io.BufferedReader)
content_type = attr.ib(type=str)
headers = attr.ib(type=CIMultiDictProxy) # type: CIMultiDictProxy[str]
_TCHAR = string.digits + string.ascii_letters + r"!#$%&'*+.^_`|~-"
# '-' at the end to prevent interpretation as range in a char class
_TOKEN = r'[{tchar}]+'.format(tchar=_TCHAR)
_QDTEXT = r'[{}]'.format(
r''.join(chr(c) for c in (0x09, 0x20, 0x21) + tuple(range(0x23, 0x7F))))
# qdtext includes 0x5C to escape 0x5D ('\]')
# qdtext excludes obs-text (because obsoleted, and encoding not specified)
_QUOTED_PAIR = r'\\[\t !-~]'
_QUOTED_STRING = r'"(?:{quoted_pair}|{qdtext})*"'.format(
qdtext=_QDTEXT, quoted_pair=_QUOTED_PAIR)
_FORWARDED_PAIR = (
r'({token})=({token}|{quoted_string})(:\d{{1,4}})?'.format(
token=_TOKEN,
quoted_string=_QUOTED_STRING))
_QUOTED_PAIR_REPLACE_RE = re.compile(r'\\([\t !-~])')
# same pattern as _QUOTED_PAIR but contains a capture group
_FORWARDED_PAIR_RE = re.compile(_FORWARDED_PAIR)
############################################################
# HTTP Request
############################################################
class BaseRequest(MutableMapping[str, Any], HeadersMixin):
POST_METHODS = {hdrs.METH_PATCH, hdrs.METH_POST, hdrs.METH_PUT,
hdrs.METH_TRACE, hdrs.METH_DELETE}
ATTRS = HeadersMixin.ATTRS | frozenset([
'_message', '_protocol', '_payload_writer', '_payload', '_headers',
'_method', '_version', '_rel_url', '_post', '_read_bytes',
'_state', '_cache', '_task', '_client_max_size', '_loop',
'_transport_sslcontext', '_transport_peername'])
def __init__(self, message: RawRequestMessage,
payload: StreamReader, protocol: 'RequestHandler',
payload_writer: AbstractStreamWriter,
task: 'asyncio.Task[None]',
loop: asyncio.AbstractEventLoop,
*, client_max_size: int=1024**2,
state: Optional[Dict[str, Any]]=None,
scheme: Optional[str]=None,
host: Optional[str]=None,
remote: Optional[str]=None) -> None:
if state is None:
state = {}
self._message = message
self._protocol = protocol
self._payload_writer = payload_writer
self._payload = payload
self._headers = message.headers
self._method = message.method
self._version = message.version
self._rel_url = message.url
self._post = None # type: Optional[MultiDictProxy[Union[str, bytes, FileField]]] # noqa
self._read_bytes = None # type: Optional[bytes]
self._state = state
self._cache = {} # type: Dict[str, Any]
self._task = task
self._client_max_size = client_max_size
self._loop = loop
transport = self._protocol.transport
assert transport is not None
self._transport_sslcontext = transport.get_extra_info('sslcontext')
self._transport_peername = transport.get_extra_info('peername')
if scheme is not None:
self._cache['scheme'] = scheme
if host is not None:
self._cache['host'] = host
if remote is not None:
self._cache['remote'] = remote
def clone(self, *, method: str=sentinel, rel_url: StrOrURL=sentinel,
headers: LooseHeaders=sentinel, scheme: str=sentinel,
host: str=sentinel,
remote: str=sentinel) -> 'BaseRequest':
"""Clone itself with replacement some attributes.
Creates and returns a new instance of Request object. If no parameters
are given, an exact copy is returned. If a parameter is not passed, it
will reuse the one from the current request object.
"""
if self._read_bytes:
raise RuntimeError("Cannot clone request "
"after reading its content")
dct = {} # type: Dict[str, Any]
if method is not sentinel:
dct['method'] = method
if rel_url is not sentinel:
new_url = URL(rel_url)
dct['url'] = new_url
dct['path'] = str(new_url)
if headers is not sentinel:
# a copy semantic
dct['headers'] = CIMultiDictProxy(CIMultiDict(headers))
dct['raw_headers'] = tuple((k.encode('utf-8'), v.encode('utf-8'))
for k, v in headers.items())
message = self._message._replace(**dct)
kwargs = {}
if scheme is not sentinel:
kwargs['scheme'] = scheme
if host is not sentinel:
kwargs['host'] = host
if remote is not sentinel:
kwargs['remote'] = remote
return self.__class__(
message,
self._payload,
self._protocol,
self._payload_writer,
self._task,
self._loop,
client_max_size=self._client_max_size,
state=self._state.copy(),
**kwargs)
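# --- Editor's usage sketch (not part of the original source). ---
# clone() builds a new request over the same payload with selected
# attributes replaced, which is handy in middlewares. Hypothetical values:
#
#     proxied = request.clone(scheme='https', host='example.com')
#     assert proxied.url.host == 'example.com'
#
# --- end of sketch ---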
@property
def task(self) -> 'asyncio.Task[None]':
return self._task
@property
def protocol(self) -> 'RequestHandler':
return self._protocol
@property
def transport(self) -> Optional[asyncio.Transport]:
if self._protocol is None:
return None
return self._protocol.transport
@property
def writer(self) -> AbstractStreamWriter:
return self._payload_writer
@reify
def message(self) -> RawRequestMessage:
warnings.warn("Request.message is deprecated",
DeprecationWarning,
stacklevel=3)
return self._message
@reify
def rel_url(self) -> URL:
return self._rel_url
@reify
def loop(self) -> asyncio.AbstractEventLoop:
warnings.warn("request.loop property is deprecated",
DeprecationWarning,
stacklevel=2)
return self._loop
# MutableMapping API
def __getitem__(self, key: str) -> Any:
return self._state[key]
def __setitem__(self, key: str, value: Any) -> None:
self._state[key] = value
def __delitem__(self, key: str) -> None:
del self._state[key]
def __len__(self) -> int:
return len(self._state)
def __iter__(self) -> Iterator[str]:
return iter(self._state)
########
@reify
def secure(self) -> bool:
"""A bool indicating if the request is handled with SSL."""
return self.scheme == 'https'
@reify
def forwarded(self) -> Tuple[Mapping[str, str], ...]:
"""A tuple containing all parsed Forwarded header(s).
Makes an effort to parse Forwarded headers as specified by RFC 7239:
- It adds one (immutable) dictionary per Forwarded 'field-value', i.e.
per proxy. The first element corresponds to the data in the Forwarded
field-value added by the first proxy encountered by the client. Each
subsequent item corresponds to those added by later proxies.
- It checks that every value has valid syntax in general as specified
in section 4: either a 'token' or a 'quoted-string'.
- It un-escapes found escape sequences.
- It does NOT validate 'by' and 'for' contents as specified in section
6.
- It does NOT validate 'host' contents (Host ABNF).
- It does NOT validate 'proto' contents for valid URI scheme names.
Returns a tuple containing one or more immutable dicts.
"""
elems = []
for field_value in self._message.headers.getall(hdrs.FORWARDED, ()):
length = len(field_value)
pos = 0
need_separator = False
elem = {} # type: Dict[str, str]
elems.append(types.MappingProxyType(elem))
while 0 <= pos < length:
match = _FORWARDED_PAIR_RE.match(field_value, pos)
if match is not None: # got a valid forwarded-pair
if need_separator:
# bad syntax here, skip to next comma
pos = field_value.find(',', pos)
else:
name, value, port = match.groups()
if value[0] == '"':
# quoted string: remove quotes and unescape
value = _QUOTED_PAIR_REPLACE_RE.sub(r'\1',
value[1:-1])
if port:
value += port
elem[name.lower()] = value
pos += len(match.group(0))
need_separator = True
elif field_value[pos] == ',': # next forwarded-element
need_separator = False
elem = {}
elems.append(types.MappingProxyType(elem))
pos += 1
elif field_value[pos] == ';': # next forwarded-pair
need_separator = False
pos += 1
elif field_value[pos] in ' \t':
# Allow whitespace even between forwarded-pairs, though
# RFC 7239 doesn't. This simplifies code and is in line
# with Postel's law.
pos += 1
else:
# bad syntax here, skip to next comma
pos = field_value.find(',', pos)
return tuple(elems)
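# --- Editor's usage sketch (not part of the original source). ---
# Given "Forwarded: for=192.0.2.60;proto=http, for=198.51.100.17" this
# property yields one immutable mapping per proxy hop, so a hypothetical
# helper can recover the original client address:
#
#     def client_addr(request):
#         forwarded = request.forwarded
#         # the first hop is the proxy closest to the client
#         return forwarded[0].get('for') if forwarded else request.remote
#
# --- end of sketch ---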
@reify
def scheme(self) -> str:
"""A string representing the scheme of the request.
Hostname is resolved in this order:
- overridden value by .clone(scheme=new_scheme) call.
- type of connection to peer: HTTPS if socket is SSL, HTTP otherwise.
'http' or 'https'.
"""
if self._transport_sslcontext:
return 'https'
else:
return 'http'
@reify
def method(self) -> str:
"""Read only property for getting HTTP method.
The value is upper-cased str like 'GET', 'POST', 'PUT' etc.
"""
return self._method
@reify
def version(self) -> Tuple[int, int]:
"""Read only property for getting HTTP version of request.
Returns aiohttp.protocol.HttpVersion instance.
"""
return self._version
@reify
def host(self) -> str:
"""Hostname of the request.
Hostname is resolved in this order:
- overridden value by .clone(host=new_host) call.
- HOST HTTP header
- socket.getfqdn() value
"""
host = self._message.headers.get(hdrs.HOST)
if host is not None:
return host
else:
return socket.getfqdn()
@reify
def remote(self) -> Optional[str]:
"""Remote IP of client initiated HTTP request.
The IP is resolved in this order:
- overridden value by .clone(remote=new_remote) call.
- peername of opened socket
"""
if isinstance(self._transport_peername, (list, tuple)):
return self._transport_peername[0]
else:
return self._transport_peername
@reify
def url(self) -> URL:
url = URL.build(scheme=self.scheme, host=self.host)
return url.join(self._rel_url)
@reify
def path(self) -> str:
"""The URL including *PATH INFO* without the host or scheme.
E.g., ``/app/blog``
"""
return self._rel_url.path
@reify
def path_qs(self) -> str:
"""The URL including PATH_INFO and the query string.
E.g., /app/blog?id=10
"""
return str(self._rel_url)
@reify
def raw_path(self) -> str:
""" The URL including raw *PATH INFO* without the host or scheme.
Warning, the path is unquoted and may contains non valid URL characters
E.g., ``/my%2Fpath%7Cwith%21some%25strange%24characters``
"""
return self._message.path
@reify
def query(self) -> 'MultiDictProxy[str]':
"""A multidict with all the variables in the query string."""
return self._rel_url.query
@reify
def query_string(self) -> str:
"""The query string in the URL.
E.g., id=10
"""
return self._rel_url.query_string
@reify
def headers(self) -> 'CIMultiDictProxy[str]':
"""A case-insensitive multidict proxy with all headers."""
return self._headers
@reify
def raw_headers(self) -> RawHeaders:
"""A sequence of pairs for all headers."""
return self._message.raw_headers
@staticmethod
def _http_date(_date_str: str) -> Optional[datetime.datetime]:
"""Process a date string, return a datetime object
"""
if _date_str is not None:
timetuple = parsedate(_date_str)
if timetuple is not None:
return datetime.datetime(*timetuple[:6],
tzinfo=datetime.timezone.utc)
return None
@reify
def if_modified_since(self) -> Optional[datetime.datetime]:
"""The value of If-Modified-Since HTTP header, or None.
This header is represented as a `datetime` object.
"""
return self._http_date(self.headers.get(hdrs.IF_MODIFIED_SINCE))
@reify
def if_unmodified_since(self) -> Optional[datetime.datetime]:
"""The value of If-Unmodified-Since HTTP header, or None.
This header is represented as a `datetime` object.
"""
return self._http_date(self.headers.get(hdrs.IF_UNMODIFIED_SINCE))
@reify
def if_range(self) -> Optional[datetime.datetime]:
"""The value of If-Range HTTP header, or None.
This header is represented as a `datetime` object.
"""
return self._http_date(self.headers.get(hdrs.IF_RANGE))
@reify
def keep_alive(self) -> bool:
"""Is keepalive enabled by client?"""
return not self._message.should_close
@reify
def cookies(self) -> Mapping[str, str]:
"""Return request cookies.
A read-only dictionary-like object.
"""
raw = self.headers.get(hdrs.COOKIE, '')
parsed = SimpleCookie(raw)
return MappingProxyType(
{key: val.value for key, val in parsed.items()})
@reify
def http_range(self) -> slice:
"""The content of Range HTTP header.
Return a slice instance.
"""
rng = self._headers.get(hdrs.RANGE)
start, end = None, None
if rng is not None:
try:
pattern = r'^bytes=(\d*)-(\d*)$'
start, end = re.findall(pattern, rng)[0]
except IndexError: # pattern was not found in header
raise ValueError("range not in acceptable format")
end = int(end) if end else None
start = int(start) if start else None
if start is None and end is not None:
# end with no start is to return tail of content
start = -end
end = None
if start is not None and end is not None:
# end is inclusive in range header, exclusive for slice
end += 1
if start >= end:
raise ValueError('start cannot be after end')
if start is end is None: # No valid range supplied
raise ValueError('No start or end of range specified')
return slice(start, end, 1)
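# --- Editor's usage sketch (not part of the original source). ---
# The returned slice maps directly onto the payload: "bytes=0-99" becomes
# slice(0, 100), "bytes=-500" becomes slice(-500, None), and a missing
# Range header yields slice(None, None, 1), i.e. the whole body:
#
#     try:
#         partial = data[request.http_range]   # data: bytes
#     except ValueError:
#         partial = data                       # malformed Range header
#
# --- end of sketch ---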
@reify
def content(self) -> StreamReader:
"""Return raw payload stream."""
return self._payload
@property
def has_body(self) -> bool:
"""Return True if request's HTTP BODY can be read, False otherwise."""
warnings.warn(
"Deprecated, use .can_read_body #2005",
DeprecationWarning, stacklevel=2)
return not self._payload.at_eof()
@property
def can_read_body(self) -> bool:
"""Return True if request's HTTP BODY can be read, False otherwise."""
return not self._payload.at_eof()
@reify
def body_exists(self) -> bool:
"""Return True if request has HTTP BODY, False otherwise."""
return type(self._payload) is not EmptyStreamReader
async def release(self) -> None:
"""Release request.
Eat unread part of HTTP BODY if present.
"""
while not self._payload.at_eof():
await self._payload.readany()
async def read(self) -> bytes:
"""Read request body if present.
Returns bytes object with full request content.
"""
if self._read_bytes is None:
body = bytearray()
while True:
chunk = await self._payload.readany()
body.extend(chunk)
if self._client_max_size:
body_size = len(body)
if body_size >= self._client_max_size:
raise HTTPRequestEntityTooLarge(
max_size=self._client_max_size,
actual_size=body_size
)
if not chunk:
break
self._read_bytes = bytes(body)
return self._read_bytes
async def text(self) -> str:
"""Return BODY as text using encoding from .charset."""
bytes_body = await self.read()
encoding = self.charset or 'utf-8'
return bytes_body.decode(encoding)
async def json(self, *, loads: JSONDecoder=DEFAULT_JSON_DECODER) -> Any:
"""Return BODY as JSON."""
body = await self.text()
return loads(body)
async def multipart(self) -> MultipartReader:
"""Return async iterator to process BODY as multipart."""
return MultipartReader(self._headers, self._payload)
async def post(self) -> 'MultiDictProxy[Union[str, bytes, FileField]]':
"""Return POST parameters."""
if self._post is not None:
return self._post
if self._method not in self.POST_METHODS:
self._post = MultiDictProxy(MultiDict())
return self._post
content_type = self.content_type
if (content_type not in ('',
'application/x-www-form-urlencoded',
'multipart/form-data')):
self._post = MultiDictProxy(MultiDict())
return self._post
out = MultiDict() # type: MultiDict[Union[str, bytes, FileField]]
if content_type == 'multipart/form-data':
multipart = await self.multipart()
max_size = self._client_max_size
field = await multipart.next()
while field is not None:
size = 0
field_ct = field.headers.get(hdrs.CONTENT_TYPE)
if isinstance(field, BodyPartReader):
if field.filename and field_ct:
# store file in temp file
tmp = tempfile.TemporaryFile()
chunk = await field.read_chunk(size=2**16)
while chunk:
chunk = field.decode(chunk)
tmp.write(chunk)
size += len(chunk)
if 0 < max_size < size:
raise HTTPRequestEntityTooLarge(
max_size=max_size,
actual_size=size
)
chunk = await field.read_chunk(size=2**16)
tmp.seek(0)
ff = FileField(field.name, field.filename,
cast(io.BufferedReader, tmp),
field_ct, field.headers)
out.add(field.name, ff)
else:
# deal with ordinary data
value = await field.read(decode=True)
if field_ct is None or \
field_ct.startswith('text/'):
charset = field.get_charset(default='utf-8')
out.add(field.name, value.decode(charset))
else:
out.add(field.name, value)
size += len(value)
if 0 < max_size < size:
raise HTTPRequestEntityTooLarge(
max_size=max_size,
actual_size=size
)
else:
raise ValueError(
'To decode nested multipart you need '
'to use custom reader',
)
field = await multipart.next()
else:
data = await self.read()
if data:
charset = self.charset or 'utf-8'
out.extend(
parse_qsl(
data.rstrip().decode(charset),
keep_blank_values=True,
encoding=charset))
self._post = MultiDictProxy(out)
return self._post
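# --- Editor's usage sketch (not part of the original source). ---
# Values in the returned multidict are str for ordinary fields and
# FileField for uploads (spooled to a temporary file above). The field
# name below is hypothetical:
#
#     form = await request.post()
#     avatar = form.get('avatar')
#     if isinstance(avatar, FileField):
#         data = avatar.file.read()
#
# --- end of sketch ---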
def __repr__(self) -> str:
ascii_encodable_path = self.path.encode('ascii', 'backslashreplace') \
.decode('ascii')
return "<{} {} {} >".format(self.__class__.__name__,
self._method, ascii_encodable_path)
def __eq__(self, other: object) -> bool:
return id(self) == id(other)
def __bool__(self) -> bool:
return True
async def _prepare_hook(self, response: StreamResponse) -> None:
return
class Request(BaseRequest):
ATTRS = BaseRequest.ATTRS | frozenset(['_match_info'])
def __init__(self, *args: Any, **kwargs: Any) -> None:
super().__init__(*args, **kwargs)
# matchdict, route_name, handler
# or information about traversal lookup
# initialized after route resolving
self._match_info = None # type: Optional[UrlMappingMatchInfo]
if DEBUG:
def __setattr__(self, name: str, val: Any) -> None:
if name not in self.ATTRS:
warnings.warn("Setting custom {}.{} attribute "
"is discouraged".format(self.__class__.__name__,
name),
DeprecationWarning,
stacklevel=2)
super().__setattr__(name, val)
def clone(self, *, method: str=sentinel, rel_url:
StrOrURL=sentinel, headers: LooseHeaders=sentinel,
scheme: str=sentinel, host: str=sentinel, remote:
str=sentinel) -> 'Request':
ret = super().clone(method=method,
rel_url=rel_url,
headers=headers,
scheme=scheme,
host=host,
remote=remote)
new_ret = cast(Request, ret)
new_ret._match_info = self._match_info
return new_ret
@reify
def match_info(self) -> 'UrlMappingMatchInfo':
"""Result of route resolving."""
match_info = self._match_info
assert match_info is not None
return match_info
@property
def app(self) -> 'Application':
"""Application instance."""
match_info = self._match_info
assert match_info is not None
return match_info.current_app
@property
def config_dict(self) -> ChainMapProxy:
match_info = self._match_info
assert match_info is not None
lst = match_info.apps
app = self.app
idx = lst.index(app)
sublist = list(reversed(lst[:idx + 1]))
return ChainMapProxy(sublist)
async def _prepare_hook(self, response: StreamResponse) -> None:
match_info = self._match_info
if match_info is None:
return
for app in match_info._apps:
await app.on_response_prepare.send(self, response)
@ -1,717 +0,0 @@
import asyncio # noqa
import collections.abc # noqa
import datetime
import enum
import json
import math
import time
import warnings
import zlib
from concurrent.futures import Executor
from email.utils import parsedate
from http.cookies import SimpleCookie
from typing import ( # noqa
TYPE_CHECKING,
Any,
Dict,
Iterator,
Mapping,
MutableMapping,
Optional,
Tuple,
Union,
cast,
)
from multidict import CIMultiDict, istr
from . import hdrs, payload
from .abc import AbstractStreamWriter
from .helpers import HeadersMixin, rfc822_formatted_time, sentinel
from .http import RESPONSES, SERVER_SOFTWARE, HttpVersion10, HttpVersion11
from .payload import Payload
from .typedefs import JSONEncoder, LooseHeaders
__all__ = ('ContentCoding', 'StreamResponse', 'Response', 'json_response')
if TYPE_CHECKING: # pragma: no cover
from .web_request import BaseRequest # noqa
BaseClass = MutableMapping[str, Any]
else:
BaseClass = collections.abc.MutableMapping
class ContentCoding(enum.Enum):
# The content codings that we have support for.
#
# Additional registered codings are listed at:
# https://www.iana.org/assignments/http-parameters/http-parameters.xhtml#content-coding
deflate = 'deflate'
gzip = 'gzip'
identity = 'identity'
############################################################
# HTTP Response classes
############################################################
class StreamResponse(BaseClass, HeadersMixin):
_length_check = True
def __init__(self, *,
status: int=200,
reason: Optional[str]=None,
headers: Optional[LooseHeaders]=None) -> None:
self._body = None
self._keep_alive = None # type: Optional[bool]
self._chunked = False
self._compression = False
self._compression_force = None # type: Optional[ContentCoding]
self._cookies = SimpleCookie()
self._req = None # type: Optional[BaseRequest]
self._payload_writer = None # type: Optional[AbstractStreamWriter]
self._eof_sent = False
self._body_length = 0
self._state = {} # type: Dict[str, Any]
if headers is not None:
self._headers = CIMultiDict(headers) # type: CIMultiDict[str]
else:
self._headers = CIMultiDict()
self.set_status(status, reason)
@property
def prepared(self) -> bool:
return self._payload_writer is not None
@property
def task(self) -> 'asyncio.Task[None]':
return getattr(self._req, 'task', None)
@property
def status(self) -> int:
return self._status
@property
def chunked(self) -> bool:
return self._chunked
@property
def compression(self) -> bool:
return self._compression
@property
def reason(self) -> str:
return self._reason
def set_status(self, status: int,
reason: Optional[str]=None,
_RESPONSES: Mapping[int,
Tuple[str, str]]=RESPONSES) -> None:
assert not self.prepared, \
'Cannot change the response status code after ' \
'the headers have been sent'
self._status = int(status)
if reason is None:
try:
reason = _RESPONSES[self._status][0]
except Exception:
reason = ''
self._reason = reason
@property
def keep_alive(self) -> Optional[bool]:
return self._keep_alive
def force_close(self) -> None:
self._keep_alive = False
@property
def body_length(self) -> int:
return self._body_length
@property
def output_length(self) -> int:
warnings.warn('output_length is deprecated', DeprecationWarning)
assert self._payload_writer
return self._payload_writer.buffer_size
def enable_chunked_encoding(self, chunk_size: Optional[int]=None) -> None:
"""Enables automatic chunked transfer encoding."""
self._chunked = True
if hdrs.CONTENT_LENGTH in self._headers:
raise RuntimeError("You can't enable chunked encoding when "
"a content length is set")
if chunk_size is not None:
warnings.warn('Chunk size is deprecated #1615', DeprecationWarning)
def enable_compression(self,
force: Optional[Union[bool, ContentCoding]]=None
) -> None:
"""Enables response compression encoding."""
# Backwards compatibility for when force was a bool <0.17.
if type(force) == bool:
force = ContentCoding.deflate if force else ContentCoding.identity
warnings.warn("Using boolean for force is deprecated #3318",
DeprecationWarning)
elif force is not None:
assert isinstance(force, ContentCoding), ("force should be one of "
"None, bool or "
"ContentCoding")
self._compression = True
self._compression_force = force
@property
def headers(self) -> 'CIMultiDict[str]':
return self._headers
@property
def cookies(self) -> SimpleCookie:
return self._cookies
def set_cookie(self, name: str, value: str, *,
expires: Optional[str]=None,
domain: Optional[str]=None,
max_age: Optional[Union[int, str]]=None,
path: str='/',
secure: Optional[str]=None,
httponly: Optional[str]=None,
version: Optional[str]=None) -> None:
"""Set or update response cookie.
Sets a new cookie or updates an existing one with a new value.
Only those params which are not None are updated.
"""
old = self._cookies.get(name)
if old is not None and old.coded_value == '':
# deleted cookie
self._cookies.pop(name, None)
self._cookies[name] = value
c = self._cookies[name]
if expires is not None:
c['expires'] = expires
elif c.get('expires') == 'Thu, 01 Jan 1970 00:00:00 GMT':
del c['expires']
if domain is not None:
c['domain'] = domain
if max_age is not None:
c['max-age'] = str(max_age)
elif 'max-age' in c:
del c['max-age']
c['path'] = path
if secure is not None:
c['secure'] = secure
if httponly is not None:
c['httponly'] = httponly
if version is not None:
c['version'] = version
def del_cookie(self, name: str, *,
domain: Optional[str]=None,
path: str='/') -> None:
"""Delete cookie.
Creates new empty expired cookie.
"""
# TODO: do we need domain/path here?
self._cookies.pop(name, None)
self.set_cookie(name, '', max_age=0,
expires="Thu, 01 Jan 1970 00:00:00 GMT",
domain=domain, path=path)
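# --- Editor's usage sketch (not part of the original source). ---
# Cookies collected here are serialized into Set-Cookie headers when the
# response is prepared; the cookie name and value are hypothetical:
#
#     resp.set_cookie('session', 'deadbeef', max_age=3600, path='/')
#     resp.del_cookie('session')  # re-issues it empty and already expired
#
# --- end of sketch ---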
@property
def content_length(self) -> Optional[int]:
# Just a placeholder for adding setter
return super().content_length
@content_length.setter
def content_length(self, value: Optional[int]) -> None:
if value is not None:
value = int(value)
if self._chunked:
raise RuntimeError("You can't set content length when "
"chunked encoding is enable")
self._headers[hdrs.CONTENT_LENGTH] = str(value)
else:
self._headers.pop(hdrs.CONTENT_LENGTH, None)
@property
def content_type(self) -> str:
# Just a placeholder for adding setter
return super().content_type
@content_type.setter
def content_type(self, value: str) -> None:
self.content_type # read header values if needed
self._content_type = str(value)
self._generate_content_type_header()
@property
def charset(self) -> Optional[str]:
# Just a placeholder for adding setter
return super().charset
@charset.setter
def charset(self, value: Optional[str]) -> None:
ctype = self.content_type # read header values if needed
if ctype == 'application/octet-stream':
raise RuntimeError("Setting charset for application/octet-stream "
"doesn't make sense, setup content_type first")
assert self._content_dict is not None
if value is None:
self._content_dict.pop('charset', None)
else:
self._content_dict['charset'] = str(value).lower()
self._generate_content_type_header()
@property
def last_modified(self) -> Optional[datetime.datetime]:
"""The value of Last-Modified HTTP header, or None.
This header is represented as a `datetime` object.
"""
httpdate = self._headers.get(hdrs.LAST_MODIFIED)
if httpdate is not None:
timetuple = parsedate(httpdate)
if timetuple is not None:
return datetime.datetime(*timetuple[:6],
tzinfo=datetime.timezone.utc)
return None
@last_modified.setter
def last_modified(self,
value: Optional[
Union[int, float, datetime.datetime, str]]) -> None:
if value is None:
self._headers.pop(hdrs.LAST_MODIFIED, None)
elif isinstance(value, (int, float)):
self._headers[hdrs.LAST_MODIFIED] = time.strftime(
"%a, %d %b %Y %H:%M:%S GMT", time.gmtime(math.ceil(value)))
elif isinstance(value, datetime.datetime):
self._headers[hdrs.LAST_MODIFIED] = time.strftime(
"%a, %d %b %Y %H:%M:%S GMT", value.utctimetuple())
elif isinstance(value, str):
self._headers[hdrs.LAST_MODIFIED] = value
def _generate_content_type_header(
self,
CONTENT_TYPE: istr=hdrs.CONTENT_TYPE) -> None:
assert self._content_dict is not None
assert self._content_type is not None
params = '; '.join("{}={}".format(k, v)
for k, v in self._content_dict.items())
if params:
ctype = self._content_type + '; ' + params
else:
ctype = self._content_type
self._headers[CONTENT_TYPE] = ctype
async def _do_start_compression(self, coding: ContentCoding) -> None:
if coding != ContentCoding.identity:
assert self._payload_writer is not None
self._headers[hdrs.CONTENT_ENCODING] = coding.value
self._payload_writer.enable_compression(coding.value)
# Compressed payload may have different content length,
# remove the header
self._headers.popall(hdrs.CONTENT_LENGTH, None)
async def _start_compression(self, request: 'BaseRequest') -> None:
if self._compression_force:
await self._do_start_compression(self._compression_force)
else:
accept_encoding = request.headers.get(
hdrs.ACCEPT_ENCODING, '').lower()
for coding in ContentCoding:
if coding.value in accept_encoding:
await self._do_start_compression(coding)
return
async def prepare(
self,
request: 'BaseRequest'
) -> Optional[AbstractStreamWriter]:
if self._eof_sent:
return None
if self._payload_writer is not None:
return self._payload_writer
await request._prepare_hook(self)
return await self._start(request)
async def _start(self, request: 'BaseRequest') -> AbstractStreamWriter:
self._req = request
keep_alive = self._keep_alive
if keep_alive is None:
keep_alive = request.keep_alive
self._keep_alive = keep_alive
version = request.version
writer = self._payload_writer = request._payload_writer
headers = self._headers
for cookie in self._cookies.values():
value = cookie.output(header='')[1:]
headers.add(hdrs.SET_COOKIE, value)
if self._compression:
await self._start_compression(request)
if self._chunked:
if version != HttpVersion11:
raise RuntimeError(
"Using chunked encoding is forbidden "
"for HTTP/{0.major}.{0.minor}".format(request.version))
writer.enable_chunking()
headers[hdrs.TRANSFER_ENCODING] = 'chunked'
if hdrs.CONTENT_LENGTH in headers:
del headers[hdrs.CONTENT_LENGTH]
elif self._length_check:
writer.length = self.content_length
if writer.length is None:
if version >= HttpVersion11:
writer.enable_chunking()
headers[hdrs.TRANSFER_ENCODING] = 'chunked'
if hdrs.CONTENT_LENGTH in headers:
del headers[hdrs.CONTENT_LENGTH]
else:
keep_alive = False
headers.setdefault(hdrs.CONTENT_TYPE, 'application/octet-stream')
headers.setdefault(hdrs.DATE, rfc822_formatted_time())
headers.setdefault(hdrs.SERVER, SERVER_SOFTWARE)
# connection header
if hdrs.CONNECTION not in headers:
if keep_alive:
if version == HttpVersion10:
headers[hdrs.CONNECTION] = 'keep-alive'
else:
if version == HttpVersion11:
headers[hdrs.CONNECTION] = 'close'
# status line
status_line = 'HTTP/{}.{} {} {}'.format(
version[0], version[1], self._status, self._reason)
await writer.write_headers(status_line, headers)
return writer
async def write(self, data: bytes) -> None:
assert isinstance(data, (bytes, bytearray, memoryview)), \
"data argument must be byte-ish (%r)" % type(data)
if self._eof_sent:
raise RuntimeError("Cannot call write() after write_eof()")
if self._payload_writer is None:
raise RuntimeError("Cannot call write() before prepare()")
await self._payload_writer.write(data)
async def drain(self) -> None:
assert not self._eof_sent, "EOF has already been sent"
assert self._payload_writer is not None, \
"Response has not been started"
warnings.warn("drain method is deprecated, use await resp.write()",
DeprecationWarning,
stacklevel=2)
await self._payload_writer.drain()
async def write_eof(self, data: bytes=b'') -> None:
assert isinstance(data, (bytes, bytearray, memoryview)), \
"data argument must be byte-ish (%r)" % type(data)
if self._eof_sent:
return
assert self._payload_writer is not None, \
"Response has not been started"
await self._payload_writer.write_eof(data)
self._eof_sent = True
self._req = None
self._body_length = self._payload_writer.output_size
self._payload_writer = None
def __repr__(self) -> str:
if self._eof_sent:
info = "eof"
elif self.prepared:
assert self._req is not None
info = "{} {} ".format(self._req.method, self._req.path)
else:
info = "not prepared"
return "<{} {} {}>".format(self.__class__.__name__,
self.reason, info)
def __getitem__(self, key: str) -> Any:
return self._state[key]
def __setitem__(self, key: str, value: Any) -> None:
self._state[key] = value
def __delitem__(self, key: str) -> None:
del self._state[key]
def __len__(self) -> int:
return len(self._state)
def __iter__(self) -> Iterator[str]:
return iter(self._state)
def __hash__(self) -> int:
return hash(id(self))
def __eq__(self, other: object) -> bool:
return self is other
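# --- Editor's usage sketch (not part of the original source). Unlike the
# pre-3.0 API, write() here is a coroutine and must be awaited; the
# handler name is hypothetical.
async def stream_numbers(request: 'BaseRequest') -> StreamResponse:
    resp = StreamResponse()
    resp.content_type = 'text/plain'
    await resp.prepare(request)  # sends the status line and headers
    for i in range(3):
        await resp.write(b'%d\n' % i)
    await resp.write_eof()
    return resp
# --- end of sketch ---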
class Response(StreamResponse):
def __init__(self, *,
body: Any=None,
status: int=200,
reason: Optional[str]=None,
text: Optional[str]=None,
headers: Optional[LooseHeaders]=None,
content_type: Optional[str]=None,
charset: Optional[str]=None,
zlib_executor_size: Optional[int]=None,
zlib_executor: Executor=None) -> None:
if body is not None and text is not None:
raise ValueError("body and text are not allowed together")
if headers is None:
real_headers = CIMultiDict() # type: CIMultiDict[str]
elif not isinstance(headers, CIMultiDict):
real_headers = CIMultiDict(headers)
else:
real_headers = headers # = cast('CIMultiDict[str]', headers)
if content_type is not None and "charset" in content_type:
raise ValueError("charset must not be in content_type "
"argument")
if text is not None:
if hdrs.CONTENT_TYPE in real_headers:
if content_type or charset:
raise ValueError("passing both Content-Type header and "
"content_type or charset params "
"is forbidden")
else:
# fast path for filling headers
if not isinstance(text, str):
raise TypeError("text argument must be str (%r)" %
type(text))
if content_type is None:
content_type = 'text/plain'
if charset is None:
charset = 'utf-8'
real_headers[hdrs.CONTENT_TYPE] = (
content_type + '; charset=' + charset)
body = text.encode(charset)
text = None
else:
if hdrs.CONTENT_TYPE in real_headers:
if content_type is not None or charset is not None:
raise ValueError("passing both Content-Type header and "
"content_type or charset params "
"is forbidden")
else:
if content_type is not None:
if charset is not None:
content_type += '; charset=' + charset
real_headers[hdrs.CONTENT_TYPE] = content_type
super().__init__(status=status, reason=reason, headers=real_headers)
if text is not None:
self.text = text
else:
self.body = body
self._compressed_body = None # type: Optional[bytes]
self._zlib_executor_size = zlib_executor_size
self._zlib_executor = zlib_executor
@property
def body(self) -> Optional[Union[bytes, Payload]]:
return self._body
@body.setter
def body(self, body: bytes,
CONTENT_TYPE: istr=hdrs.CONTENT_TYPE,
CONTENT_LENGTH: istr=hdrs.CONTENT_LENGTH) -> None:
if body is None:
self._body = None # type: Optional[bytes]
self._body_payload = False # type: bool
elif isinstance(body, (bytes, bytearray)):
self._body = body
self._body_payload = False
else:
try:
self._body = body = payload.PAYLOAD_REGISTRY.get(body)
except payload.LookupError:
raise ValueError('Unsupported body type %r' % type(body))
self._body_payload = True
headers = self._headers
# set content-length header if needed
if not self._chunked and CONTENT_LENGTH not in headers:
size = body.size
if size is not None:
headers[CONTENT_LENGTH] = str(size)
# set content-type
if CONTENT_TYPE not in headers:
headers[CONTENT_TYPE] = body.content_type
# copy payload headers
if body.headers:
for (key, value) in body.headers.items():
if key not in headers:
headers[key] = value
self._compressed_body = None
@property
def text(self) -> Optional[str]:
if self._body is None:
return None
return self._body.decode(self.charset or 'utf-8')
@text.setter
def text(self, text: str) -> None:
assert text is None or isinstance(text, str), \
"text argument must be str (%r)" % type(text)
if self.content_type == 'application/octet-stream':
self.content_type = 'text/plain'
if self.charset is None:
self.charset = 'utf-8'
self._body = text.encode(self.charset)
self._body_payload = False
self._compressed_body = None
@property
def content_length(self) -> Optional[int]:
if self._chunked:
return None
if hdrs.CONTENT_LENGTH in self._headers:
return super().content_length
if self._compressed_body is not None:
# Return length of the compressed body
return len(self._compressed_body)
elif self._body_payload:
# A payload without content length, or a compressed payload
return None
elif self._body is not None:
return len(self._body)
else:
return 0
@content_length.setter
def content_length(self, value: Optional[int]) -> None:
raise RuntimeError("Content length is set automatically")
async def write_eof(self, data: bytes=b'') -> None:
if self._eof_sent:
return
if self._compressed_body is None:
body = self._body # type: Optional[Union[bytes, Payload]]
else:
body = self._compressed_body
assert not data, "data arg is not supported, got {!r}".format(data)
assert self._req is not None
assert self._payload_writer is not None
if body is not None:
if (self._req._method == hdrs.METH_HEAD or
self._status in [204, 304]):
await super().write_eof()
elif self._body_payload:
payload = cast(Payload, body)
await payload.write(self._payload_writer)
await super().write_eof()
else:
await super().write_eof(cast(bytes, body))
else:
await super().write_eof()
async def _start(self, request: 'BaseRequest') -> AbstractStreamWriter:
if not self._chunked and hdrs.CONTENT_LENGTH not in self._headers:
if not self._body_payload:
if self._body is not None:
self._headers[hdrs.CONTENT_LENGTH] = str(len(self._body))
else:
self._headers[hdrs.CONTENT_LENGTH] = '0'
return await super()._start(request)
def _compress_body(self, zlib_mode: int) -> None:
compressobj = zlib.compressobj(wbits=zlib_mode)
body_in = self._body
assert body_in is not None
self._compressed_body = \
compressobj.compress(body_in) + compressobj.flush()
async def _do_start_compression(self, coding: ContentCoding) -> None:
if self._body_payload or self._chunked:
return await super()._do_start_compression(coding)
if coding != ContentCoding.identity:
# Instead of using _payload_writer.enable_compression,
# compress the whole body
zlib_mode = (16 + zlib.MAX_WBITS
if coding == ContentCoding.gzip else -zlib.MAX_WBITS)
body_in = self._body
assert body_in is not None
if self._zlib_executor_size is not None and \
len(body_in) > self._zlib_executor_size:
await asyncio.get_event_loop().run_in_executor(
self._zlib_executor, self._compress_body, zlib_mode)
else:
self._compress_body(zlib_mode)
body_out = self._compressed_body
assert body_out is not None
self._headers[hdrs.CONTENT_ENCODING] = coding.value
self._headers[hdrs.CONTENT_LENGTH] = str(len(body_out))
def json_response(data: Any=sentinel, *,
text: str=None,
body: bytes=None,
status: int=200,
reason: Optional[str]=None,
headers: LooseHeaders=None,
content_type: str='application/json',
dumps: JSONEncoder=json.dumps) -> Response:
if data is not sentinel:
if text or body:
raise ValueError(
"only one of data, text, or body should be specified"
)
else:
text = dumps(data)
return Response(text=text, body=body, status=status, reason=reason,
headers=headers, content_type=content_type)
@ -1,194 +0,0 @@
import abc
import os # noqa
from typing import (
TYPE_CHECKING,
Any,
Awaitable,
Callable,
Dict,
Iterator,
List,
Optional,
Sequence,
Type,
Union,
overload,
)
import attr
from . import hdrs
from .abc import AbstractView
from .typedefs import PathLike
if TYPE_CHECKING: # pragma: no cover
from .web_urldispatcher import UrlDispatcher
from .web_request import Request
from .web_response import StreamResponse
else:
Request = StreamResponse = UrlDispatcher = None
__all__ = ('AbstractRouteDef', 'RouteDef', 'StaticDef', 'RouteTableDef',
'head', 'options', 'get', 'post', 'patch', 'put', 'delete',
'route', 'view', 'static')
class AbstractRouteDef(abc.ABC):
@abc.abstractmethod
def register(self, router: UrlDispatcher) -> None:
pass # pragma: no cover
_SimpleHandler = Callable[[Request], Awaitable[StreamResponse]]
_HandlerType = Union[Type[AbstractView], _SimpleHandler]
@attr.s(frozen=True, repr=False, slots=True)
class RouteDef(AbstractRouteDef):
method = attr.ib(type=str)
path = attr.ib(type=str)
handler = attr.ib() # type: _HandlerType
kwargs = attr.ib(type=Dict[str, Any])
def __repr__(self) -> str:
info = []
for name, value in sorted(self.kwargs.items()):
info.append(", {}={!r}".format(name, value))
return ("<RouteDef {method} {path} -> {handler.__name__!r}"
"{info}>".format(method=self.method, path=self.path,
handler=self.handler, info=''.join(info)))
def register(self, router: UrlDispatcher) -> None:
if self.method in hdrs.METH_ALL:
reg = getattr(router, 'add_'+self.method.lower())
reg(self.path, self.handler, **self.kwargs)
else:
router.add_route(self.method, self.path, self.handler,
**self.kwargs)
@attr.s(frozen=True, repr=False, slots=True)
class StaticDef(AbstractRouteDef):
prefix = attr.ib(type=str)
path = attr.ib() # type: PathLike
kwargs = attr.ib(type=Dict[str, Any])
def __repr__(self) -> str:
info = []
for name, value in sorted(self.kwargs.items()):
info.append(", {}={!r}".format(name, value))
return ("<StaticDef {prefix} -> {path}"
"{info}>".format(prefix=self.prefix, path=self.path,
info=''.join(info)))
def register(self, router: UrlDispatcher) -> None:
router.add_static(self.prefix, self.path, **self.kwargs)
def route(method: str, path: str, handler: _HandlerType,
**kwargs: Any) -> RouteDef:
return RouteDef(method, path, handler, kwargs)
def head(path: str, handler: _HandlerType, **kwargs: Any) -> RouteDef:
return route(hdrs.METH_HEAD, path, handler, **kwargs)
def options(path: str, handler: _HandlerType, **kwargs: Any) -> RouteDef:
return route(hdrs.METH_OPTIONS, path, handler, **kwargs)
def get(path: str, handler: _HandlerType, *, name: Optional[str]=None,
allow_head: bool=True, **kwargs: Any) -> RouteDef:
return route(hdrs.METH_GET, path, handler, name=name,
allow_head=allow_head, **kwargs)
def post(path: str, handler: _HandlerType, **kwargs: Any) -> RouteDef:
return route(hdrs.METH_POST, path, handler, **kwargs)
def put(path: str, handler: _HandlerType, **kwargs: Any) -> RouteDef:
return route(hdrs.METH_PUT, path, handler, **kwargs)
def patch(path: str, handler: _HandlerType, **kwargs: Any) -> RouteDef:
return route(hdrs.METH_PATCH, path, handler, **kwargs)
def delete(path: str, handler: _HandlerType, **kwargs: Any) -> RouteDef:
return route(hdrs.METH_DELETE, path, handler, **kwargs)
def view(path: str, handler: Type[AbstractView], **kwargs: Any) -> RouteDef:
return route(hdrs.METH_ANY, path, handler, **kwargs)
def static(prefix: str, path: PathLike,
**kwargs: Any) -> StaticDef:
return StaticDef(prefix, path, kwargs)
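# --- Editor's usage sketch (not part of the original source). The helpers
# above build RouteDef/StaticDef objects that an application applies in
# bulk via Application.add_routes(); the handler names are hypothetical:
#
#     app.add_routes([
#         get('/', index),
#         post('/items', create_item),
#         static('/assets', './assets'),
#     ])
#
# --- end of sketch ---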
_Deco = Callable[[_HandlerType], _HandlerType]
class RouteTableDef(Sequence[AbstractRouteDef]):
"""Route definition table"""
def __init__(self) -> None:
self._items = [] # type: List[AbstractRouteDef]
def __repr__(self) -> str:
return "<RouteTableDef count={}>".format(len(self._items))
@overload
def __getitem__(self, index: int) -> AbstractRouteDef: ... # noqa
@overload # noqa
def __getitem__(self, index: slice) -> List[AbstractRouteDef]: ... # noqa
def __getitem__(self, index): # type: ignore # noqa
return self._items[index]
def __iter__(self) -> Iterator[AbstractRouteDef]:
return iter(self._items)
def __len__(self) -> int:
return len(self._items)
def __contains__(self, item: object) -> bool:
return item in self._items
def route(self,
method: str,
path: str,
**kwargs: Any) -> _Deco:
def inner(handler: _HandlerType) -> _HandlerType:
self._items.append(RouteDef(method, path, handler, kwargs))
return handler
return inner
def head(self, path: str, **kwargs: Any) -> _Deco:
return self.route(hdrs.METH_HEAD, path, **kwargs)
def get(self, path: str, **kwargs: Any) -> _Deco:
return self.route(hdrs.METH_GET, path, **kwargs)
def post(self, path: str, **kwargs: Any) -> _Deco:
return self.route(hdrs.METH_POST, path, **kwargs)
def put(self, path: str, **kwargs: Any) -> _Deco:
return self.route(hdrs.METH_PUT, path, **kwargs)
def patch(self, path: str, **kwargs: Any) -> _Deco:
return self.route(hdrs.METH_PATCH, path, **kwargs)
def delete(self, path: str, **kwargs: Any) -> _Deco:
return self.route(hdrs.METH_DELETE, path, **kwargs)
def view(self, path: str, **kwargs: Any) -> _Deco:
return self.route(hdrs.METH_ANY, path, **kwargs)
def static(self, prefix: str, path: PathLike,
**kwargs: Any) -> None:
self._items.append(StaticDef(prefix, path, kwargs))
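# --- Illustrative aside (not part of the commit): RouteTableDef is designed
# to be used as a decorator collection; a minimal sketch (names made up):
#
#     from aiohttp import web
#
#     routes = web.RouteTableDef()
#
#     @routes.get('/ping')
#     async def ping(request):
#         return web.Response(text='pong')
#
#     app = web.Application()
#     app.add_routes(routes)   # registers each collected RouteDef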

@@ -1,337 +0,0 @@
import asyncio
import signal
import socket
from abc import ABC, abstractmethod
from typing import Any, List, Optional, Set
from yarl import URL
from .web_app import Application
from .web_server import Server
try:
from ssl import SSLContext
except ImportError:
SSLContext = object # type: ignore
__all__ = ('BaseSite', 'TCPSite', 'UnixSite', 'NamedPipeSite', 'SockSite',
'BaseRunner', 'AppRunner', 'ServerRunner', 'GracefulExit')
class GracefulExit(SystemExit):
code = 1
def _raise_graceful_exit() -> None:
raise GracefulExit()
class BaseSite(ABC):
__slots__ = ('_runner', '_shutdown_timeout', '_ssl_context', '_backlog',
'_server')
def __init__(self, runner: 'BaseRunner', *,
shutdown_timeout: float=60.0,
ssl_context: Optional[SSLContext]=None,
backlog: int=128) -> None:
if runner.server is None:
raise RuntimeError("Call runner.setup() before making a site")
self._runner = runner
self._shutdown_timeout = shutdown_timeout
self._ssl_context = ssl_context
self._backlog = backlog
self._server = None # type: Optional[asyncio.AbstractServer]
@property
@abstractmethod
def name(self) -> str:
pass # pragma: no cover
@abstractmethod
async def start(self) -> None:
self._runner._reg_site(self)
async def stop(self) -> None:
self._runner._check_site(self)
if self._server is None:
self._runner._unreg_site(self)
return # not started yet
self._server.close()
# named pipes do not have wait_closed property
if hasattr(self._server, 'wait_closed'):
await self._server.wait_closed()
await self._runner.shutdown()
assert self._runner.server
await self._runner.server.shutdown(self._shutdown_timeout)
self._runner._unreg_site(self)
class TCPSite(BaseSite):
__slots__ = ('_host', '_port', '_reuse_address', '_reuse_port')
def __init__(self, runner: 'BaseRunner',
host: Optional[str]=None, port: Optional[int]=None, *,
shutdown_timeout: float=60.0,
ssl_context: Optional[SSLContext]=None,
backlog: int=128, reuse_address: Optional[bool]=None,
reuse_port: Optional[bool]=None) -> None:
super().__init__(runner, shutdown_timeout=shutdown_timeout,
ssl_context=ssl_context, backlog=backlog)
if host is None:
host = "0.0.0.0"
self._host = host
if port is None:
port = 8443 if self._ssl_context else 8080
self._port = port
self._reuse_address = reuse_address
self._reuse_port = reuse_port
@property
def name(self) -> str:
scheme = 'https' if self._ssl_context else 'http'
return str(URL.build(scheme=scheme, host=self._host, port=self._port))
async def start(self) -> None:
await super().start()
loop = asyncio.get_event_loop()
server = self._runner.server
assert server is not None
self._server = await loop.create_server( # type: ignore
server, self._host, self._port,
ssl=self._ssl_context, backlog=self._backlog,
reuse_address=self._reuse_address,
reuse_port=self._reuse_port)
class UnixSite(BaseSite):
__slots__ = ('_path', )
def __init__(self, runner: 'BaseRunner', path: str, *,
shutdown_timeout: float=60.0,
ssl_context: Optional[SSLContext]=None,
backlog: int=128) -> None:
super().__init__(runner, shutdown_timeout=shutdown_timeout,
ssl_context=ssl_context, backlog=backlog)
self._path = path
@property
def name(self) -> str:
scheme = 'https' if self._ssl_context else 'http'
return '{}://unix:{}:'.format(scheme, self._path)
async def start(self) -> None:
await super().start()
loop = asyncio.get_event_loop()
server = self._runner.server
assert server is not None
self._server = await loop.create_unix_server(
server, self._path,
ssl=self._ssl_context, backlog=self._backlog)
class NamedPipeSite(BaseSite):
__slots__ = ('_path', )
def __init__(self, runner: 'BaseRunner', path: str, *,
shutdown_timeout: float=60.0) -> None:
loop = asyncio.get_event_loop()
if not isinstance(loop, asyncio.ProactorEventLoop): # type: ignore
raise RuntimeError("Named Pipes only available in proactor"
"loop under windows")
super().__init__(runner, shutdown_timeout=shutdown_timeout)
self._path = path
@property
def name(self) -> str:
return self._path
async def start(self) -> None:
await super().start()
loop = asyncio.get_event_loop()
server = self._runner.server
assert server is not None
_server = await loop.start_serving_pipe( # type: ignore
server, self._path
)
self._server = _server[0]
class SockSite(BaseSite):
__slots__ = ('_sock', '_name')
def __init__(self, runner: 'BaseRunner', sock: socket.socket, *,
shutdown_timeout: float=60.0,
ssl_context: Optional[SSLContext]=None,
backlog: int=128) -> None:
super().__init__(runner, shutdown_timeout=shutdown_timeout,
ssl_context=ssl_context, backlog=backlog)
self._sock = sock
scheme = 'https' if self._ssl_context else 'http'
if hasattr(socket, 'AF_UNIX') and sock.family == socket.AF_UNIX:
name = '{}://unix:{}:'.format(scheme, sock.getsockname())
else:
host, port = sock.getsockname()[:2]
name = str(URL.build(scheme=scheme, host=host, port=port))
self._name = name
@property
def name(self) -> str:
return self._name
async def start(self) -> None:
await super().start()
loop = asyncio.get_event_loop()
server = self._runner.server
assert server is not None
self._server = await loop.create_server( # type: ignore
server, sock=self._sock,
ssl=self._ssl_context, backlog=self._backlog)
class BaseRunner(ABC):
__slots__ = ('_handle_signals', '_kwargs', '_server', '_sites')
def __init__(self, *, handle_signals: bool=False, **kwargs: Any) -> None:
self._handle_signals = handle_signals
self._kwargs = kwargs
self._server = None # type: Optional[Server]
self._sites = [] # type: List[BaseSite]
@property
def server(self) -> Optional[Server]:
return self._server
@property
def addresses(self) -> List[str]:
ret = [] # type: List[str]
for site in self._sites:
server = site._server
if server is not None:
sockets = server.sockets
if sockets is not None:
for sock in sockets:
ret.append(sock.getsockname())
return ret
@property
def sites(self) -> Set[BaseSite]:
return set(self._sites)
async def setup(self) -> None:
loop = asyncio.get_event_loop()
if self._handle_signals:
try:
loop.add_signal_handler(signal.SIGINT, _raise_graceful_exit)
loop.add_signal_handler(signal.SIGTERM, _raise_graceful_exit)
except NotImplementedError: # pragma: no cover
# add_signal_handler is not implemented on Windows
pass
self._server = await self._make_server()
@abstractmethod
async def shutdown(self) -> None:
pass # pragma: no cover
async def cleanup(self) -> None:
loop = asyncio.get_event_loop()
if self._server is None:
# not started yet, do nothing
return
# The loop over sites is intentional: an exception from gather()
# would leave self._sites in an unpredictable state.
# The loop guarantees that a site is either deleted on success or
# still present on failure.
for site in list(self._sites):
await site.stop()
await self._cleanup_server()
self._server = None
if self._handle_signals:
try:
loop.remove_signal_handler(signal.SIGINT)
loop.remove_signal_handler(signal.SIGTERM)
except NotImplementedError: # pragma: no cover
# remove_signal_handler is not implemented on Windows
pass
@abstractmethod
async def _make_server(self) -> Server:
pass # pragma: no cover
@abstractmethod
async def _cleanup_server(self) -> None:
pass # pragma: no cover
def _reg_site(self, site: BaseSite) -> None:
if site in self._sites:
raise RuntimeError("Site {} is already registered in runner {}"
.format(site, self))
self._sites.append(site)
def _check_site(self, site: BaseSite) -> None:
if site not in self._sites:
raise RuntimeError("Site {} is not registered in runner {}"
.format(site, self))
def _unreg_site(self, site: BaseSite) -> None:
if site not in self._sites:
raise RuntimeError("Site {} is not registered in runner {}"
.format(site, self))
self._sites.remove(site)
class ServerRunner(BaseRunner):
"""Low-level web server runner"""
__slots__ = ('_web_server',)
def __init__(self, web_server: Server, *,
handle_signals: bool=False, **kwargs: Any) -> None:
super().__init__(handle_signals=handle_signals, **kwargs)
self._web_server = web_server
async def shutdown(self) -> None:
pass
async def _make_server(self) -> Server:
return self._web_server
async def _cleanup_server(self) -> None:
pass
class AppRunner(BaseRunner):
"""Web Application runner"""
__slots__ = ('_app',)
def __init__(self, app: Application, *,
handle_signals: bool=False, **kwargs: Any) -> None:
super().__init__(handle_signals=handle_signals, **kwargs)
if not isinstance(app, Application):
raise TypeError("The first argument should be web.Application "
"instance, got {!r}".format(app))
self._app = app
@property
def app(self) -> Application:
return self._app
async def shutdown(self) -> None:
await self._app.shutdown()
async def _make_server(self) -> Server:
loop = asyncio.get_event_loop()
self._app._set_loop(loop)
self._app.on_startup.freeze()
await self._app.startup()
self._app.freeze()
return self._app._make_handler(loop=loop, **self._kwargs)
async def _cleanup_server(self) -> None:
await self._app.cleanup()
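# --- Illustrative aside (not part of the commit): typical lifecycle of the
# runner/site machinery above, assuming the public aiohttp `web` namespace
# re-exports these classes as in released versions:
#
#     import asyncio
#     from aiohttp import web
#
#     async def main():
#         runner = web.AppRunner(web.Application())
#         await runner.setup()                   # builds the low-level Server
#         site = web.TCPSite(runner, '127.0.0.1', 8080)
#         await site.start()                     # registers the site + listens
#         try:
#             await asyncio.sleep(3600)          # serve for a while
#         finally:
#             await runner.cleanup()             # stops sites, then the server
#
#     asyncio.get_event_loop().run_until_complete(main())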

@@ -1,57 +0,0 @@
"""Low level HTTP server."""
import asyncio
from typing import Any, Awaitable, Callable, Dict, List, Optional # noqa
from .abc import AbstractStreamWriter
from .helpers import get_running_loop
from .http_parser import RawRequestMessage
from .streams import StreamReader
from .web_protocol import RequestHandler, _RequestFactory, _RequestHandler
from .web_request import BaseRequest
__all__ = ('Server',)
class Server:
def __init__(self,
handler: _RequestHandler,
*,
request_factory: Optional[_RequestFactory]=None,
loop: Optional[asyncio.AbstractEventLoop]=None,
**kwargs: Any) -> None:
self._loop = get_running_loop(loop)
self._connections = {} # type: Dict[RequestHandler, asyncio.Transport]
self._kwargs = kwargs
self.requests_count = 0
self.request_handler = handler
self.request_factory = request_factory or self._make_request
@property
def connections(self) -> List[RequestHandler]:
return list(self._connections.keys())
def connection_made(self, handler: RequestHandler,
transport: asyncio.Transport) -> None:
self._connections[handler] = transport
def connection_lost(self, handler: RequestHandler,
exc: Optional[BaseException]=None) -> None:
if handler in self._connections:
del self._connections[handler]
def _make_request(self, message: RawRequestMessage,
payload: StreamReader,
protocol: RequestHandler,
writer: AbstractStreamWriter,
task: 'asyncio.Task[None]') -> BaseRequest:
return BaseRequest(
message, payload, protocol, writer, task, self._loop)
async def shutdown(self, timeout: Optional[float]=None) -> None:
coros = [conn.shutdown(timeout) for conn in self._connections]
await asyncio.gather(*coros, loop=self._loop)
self._connections.clear()
def __call__(self) -> RequestHandler:
return RequestHandler(self, loop=self._loop, **self._kwargs)
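# --- Illustrative aside (not part of the commit): Server is a protocol
# factory; each __call__() produces one RequestHandler per connection. A
# hedged sketch of the low-level pairing with ServerRunner/TCPSite:
#
#     from aiohttp import web
#
#     async def handler(request):
#         return web.Response(text='OK')
#
#     async def main():
#         runner = web.ServerRunner(web.Server(handler))
#         await runner.setup()
#         await web.TCPSite(runner, 'localhost', 8080).start()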

File diff suppressed because it is too large Load Diff

@@ -1,328 +1,192 @@
import asyncio
import base64
import binascii
import hashlib
import json
from typing import Any, Iterable, Optional, Tuple
import async_timeout
import attr
from multidict import CIMultiDict
from . import hdrs
from .abc import AbstractStreamWriter
from .helpers import call_later, set_result
from .http import (
WS_CLOSED_MESSAGE,
WS_CLOSING_MESSAGE,
WS_KEY,
WebSocketError,
WebSocketReader,
WebSocketWriter,
WSMessage,
)
from .http import WSMsgType as WSMsgType
from .http import ws_ext_gen, ws_ext_parse
from .log import ws_logger
from .streams import EofStream, FlowControlDataQueue
from .typedefs import JSONDecoder, JSONEncoder
from .web_exceptions import HTTPBadRequest, HTTPException
from .web_request import BaseRequest
from .web_response import StreamResponse
__all__ = ('WebSocketResponse', 'WebSocketReady', 'WSMsgType',)
import sys
import warnings
from collections import namedtuple
from . import Timeout, hdrs
from ._ws_impl import (CLOSED_MESSAGE, WebSocketError, WSMessage, WSMsgType,
do_handshake)
from .errors import ClientDisconnectedError, HttpProcessingError
from .web_exceptions import (HTTPBadRequest, HTTPInternalServerError,
HTTPMethodNotAllowed)
from .web_reqrep import StreamResponse
__all__ = ('WebSocketResponse', 'WebSocketReady', 'MsgType', 'WSMsgType',)
PY_35 = sys.version_info >= (3, 5)
PY_352 = sys.version_info >= (3, 5, 2)
THRESHOLD_CONNLOST_ACCESS = 5
@attr.s(frozen=True, slots=True)
class WebSocketReady:
ok = attr.ib(type=bool)
protocol = attr.ib(type=Optional[str])
# deprecated since 1.0
MsgType = WSMsgType
def __bool__(self) -> bool:
class WebSocketReady(namedtuple('WebSocketReady', 'ok protocol')):
def __bool__(self):
return self.ok
class WebSocketResponse(StreamResponse):
_length_check = False
def __init__(self, *,
timeout: float=10.0, receive_timeout: Optional[float]=None,
autoclose: bool=True, autoping: bool=True,
heartbeat: Optional[float]=None,
protocols: Iterable[str]=(),
compress: bool=True, max_msg_size: int=4*1024*1024) -> None:
timeout=10.0, autoclose=True, autoping=True, protocols=()):
super().__init__(status=101)
self._protocols = protocols
self._ws_protocol = None # type: Optional[str]
self._writer = None # type: Optional[WebSocketWriter]
self._reader = None # type: Optional[FlowControlDataQueue[WSMessage]]
self._protocol = None
self._writer = None
self._reader = None
self._closed = False
self._closing = False
self._conn_lost = 0
self._close_code = None # type: Optional[int]
self._loop = None # type: Optional[asyncio.AbstractEventLoop]
self._waiting = None # type: Optional[asyncio.Future[bool]]
self._exception = None # type: Optional[BaseException]
self._close_code = None
self._loop = None
self._waiting = False
self._exception = None
self._timeout = timeout
self._receive_timeout = receive_timeout
self._autoclose = autoclose
self._autoping = autoping
self._heartbeat = heartbeat
self._heartbeat_cb = None
if heartbeat is not None:
self._pong_heartbeat = heartbeat / 2.0
self._pong_response_cb = None
self._compress = compress
self._max_msg_size = max_msg_size
def _cancel_heartbeat(self) -> None:
if self._pong_response_cb is not None:
self._pong_response_cb.cancel()
self._pong_response_cb = None
if self._heartbeat_cb is not None:
self._heartbeat_cb.cancel()
self._heartbeat_cb = None
def _reset_heartbeat(self) -> None:
self._cancel_heartbeat()
if self._heartbeat is not None:
self._heartbeat_cb = call_later(
self._send_heartbeat, self._heartbeat, self._loop)
def _send_heartbeat(self) -> None:
if self._heartbeat is not None and not self._closed:
# Firing-and-forgetting a task is not perfect, but it is probably OK
# for sending a ping; otherwise we would need a long-lived heartbeat
# task in the class.
self._loop.create_task(self._writer.ping()) # type: ignore
if self._pong_response_cb is not None:
self._pong_response_cb.cancel()
self._pong_response_cb = call_later(
self._pong_not_received, self._pong_heartbeat, self._loop)
def _pong_not_received(self) -> None:
if self._req is not None and self._req.transport is not None:
self._closed = True
self._close_code = 1006
self._exception = asyncio.TimeoutError()
self._req.transport.close()
async def prepare(self, request: BaseRequest) -> AbstractStreamWriter:
@asyncio.coroutine
def prepare(self, request):
# run the pre-check first so its errors are not hidden by do_handshake() exceptions
if self._payload_writer is not None:
return self._payload_writer
protocol, writer = self._pre_start(request)
payload_writer = await super().prepare(request)
assert payload_writer is not None
self._post_start(request, protocol, writer)
await payload_writer.drain()
return payload_writer
def _handshake(self, request: BaseRequest) -> Tuple['CIMultiDict[str]',
str,
bool,
bool]:
headers = request.headers
if 'websocket' != headers.get(hdrs.UPGRADE, '').lower().strip():
raise HTTPBadRequest(
text=('No WebSocket UPGRADE hdr: {}\n Can '
'"Upgrade" only to "WebSocket".')
.format(headers.get(hdrs.UPGRADE)))
if 'upgrade' not in headers.get(hdrs.CONNECTION, '').lower():
raise HTTPBadRequest(
text='No CONNECTION upgrade hdr: {}'.format(
headers.get(hdrs.CONNECTION)))
# find common sub-protocol between client and server
protocol = None
if hdrs.SEC_WEBSOCKET_PROTOCOL in headers:
req_protocols = [str(proto.strip()) for proto in
headers[hdrs.SEC_WEBSOCKET_PROTOCOL].split(',')]
for proto in req_protocols:
if proto in self._protocols:
protocol = proto
break
else:
# No overlap found: Return no protocol as per spec
ws_logger.warning(
"Client protocols %r don't overlap server-known ones %r",
req_protocols, self._protocols)
# check supported version
version = headers.get(hdrs.SEC_WEBSOCKET_VERSION, '')
if version not in ('13', '8', '7'):
raise HTTPBadRequest(
text='Unsupported version: {}'.format(version))
# check client handshake for validity
key = headers.get(hdrs.SEC_WEBSOCKET_KEY)
resp_impl = self._start_pre_check(request)
if resp_impl is not None:
return resp_impl
parser, protocol, writer = self._pre_start(request)
resp_impl = yield from super().prepare(request)
self._post_start(request, parser, protocol, writer)
return resp_impl
def _pre_start(self, request):
try:
if not key or len(base64.b64decode(key)) != 16:
raise HTTPBadRequest(
text='Handshake error: {!r}'.format(key))
except binascii.Error:
raise HTTPBadRequest(
text='Handshake error: {!r}'.format(key)) from None
accept_val = base64.b64encode(
hashlib.sha1(key.encode() + WS_KEY).digest()).decode()
response_headers = CIMultiDict( # type: ignore
{hdrs.UPGRADE: 'websocket',
hdrs.CONNECTION: 'upgrade',
hdrs.SEC_WEBSOCKET_ACCEPT: accept_val})
notakeover = False
compress = 0
if self._compress:
extensions = headers.get(hdrs.SEC_WEBSOCKET_EXTENSIONS)
# The server-side call always returns without raising;
# if anything went wrong, the compress extension is simply dropped
compress, notakeover = ws_ext_parse(extensions, isserver=True)
if compress:
enabledext = ws_ext_gen(compress=compress, isserver=True,
server_notakeover=notakeover)
response_headers[hdrs.SEC_WEBSOCKET_EXTENSIONS] = enabledext
if protocol:
response_headers[hdrs.SEC_WEBSOCKET_PROTOCOL] = protocol
return (response_headers, # type: ignore
protocol,
compress,
notakeover)
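# --- Illustrative aside (not part of the commit): the accept value computed
# above is the standard RFC 6455 handshake: SHA-1 over the client's
# Sec-WebSocket-Key concatenated with a fixed GUID, base64-encoded.
# Standalone form:
#
#     import base64, hashlib
#     WS_KEY = b'258EAFA5-E914-47DA-95CA-C5AB0DC85B11'
#
#     def ws_accept(key: str) -> str:
#         return base64.b64encode(
#             hashlib.sha1(key.encode() + WS_KEY).digest()).decode()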
def _pre_start(self, request: BaseRequest) -> Tuple[str, WebSocketWriter]:
self._loop = request._loop
headers, protocol, compress, notakeover = self._handshake(
request)
self._reset_heartbeat()
self.set_status(101)
self.headers.update(headers)
status, headers, parser, writer, protocol = do_handshake(
request.method, request.headers, request.transport,
self._protocols)
except HttpProcessingError as err:
if err.code == 405:
raise HTTPMethodNotAllowed(
request.method, [hdrs.METH_GET], body=b'')
elif err.code == 400:
raise HTTPBadRequest(text=err.message, headers=err.headers)
else: # pragma: no cover
raise HTTPInternalServerError() from err
if self.status != status:
self.set_status(status)
for k, v in headers:
self.headers[k] = v
self.force_close()
self._compress = compress
transport = request._protocol.transport
assert transport is not None
writer = WebSocketWriter(request._protocol,
transport,
compress=compress,
notakeover=notakeover)
return protocol, writer
def _post_start(self, request: BaseRequest,
protocol: str, writer: WebSocketWriter) -> None:
self._ws_protocol = protocol
return parser, protocol, writer
def _post_start(self, request, parser, protocol, writer):
self._reader = request._reader.set_parser(parser)
self._writer = writer
loop = self._loop
assert loop is not None
self._reader = FlowControlDataQueue(
request._protocol, limit=2 ** 16, loop=loop)
request.protocol.set_parser(WebSocketReader(
self._reader, self._max_msg_size, compress=self._compress))
# disable HTTP keepalive for WebSocket
request.protocol.keep_alive(False)
def can_prepare(self, request: BaseRequest) -> WebSocketReady:
self._protocol = protocol
self._loop = request.app.loop
def start(self, request):
warnings.warn('use .prepare(request) instead', DeprecationWarning)
# run the pre-check first so its errors are not hidden by do_handshake() exceptions
resp_impl = self._start_pre_check(request)
if resp_impl is not None:
return resp_impl
parser, protocol, writer = self._pre_start(request)
resp_impl = super().start(request)
self._post_start(request, parser, protocol, writer)
return resp_impl
def can_prepare(self, request):
if self._writer is not None:
raise RuntimeError('Already started')
try:
_, protocol, _, _ = self._handshake(request)
except HTTPException:
_, _, _, _, protocol = do_handshake(
request.method, request.headers, request.transport,
self._protocols)
except HttpProcessingError:
return WebSocketReady(False, None)
else:
return WebSocketReady(True, protocol)
def can_start(self, request):
warnings.warn('use .can_prepare(request) instead', DeprecationWarning)
return self.can_prepare(request)
@property
def closed(self) -> bool:
def closed(self):
return self._closed
@property
def close_code(self) -> Optional[int]:
def close_code(self):
return self._close_code
@property
def ws_protocol(self) -> Optional[str]:
return self._ws_protocol
@property
def compress(self) -> bool:
return self._compress
def protocol(self):
return self._protocol
def exception(self) -> Optional[BaseException]:
def exception(self):
return self._exception
async def ping(self, message: bytes=b'') -> None:
def ping(self, message=b''):
if self._writer is None:
raise RuntimeError('Call .prepare() first')
await self._writer.ping(message)
if self._closed:
raise RuntimeError('websocket connection is closing')
self._writer.ping(message)
async def pong(self, message: bytes=b'') -> None:
def pong(self, message=b''):
# unsolicited pong
if self._writer is None:
raise RuntimeError('Call .prepare() first')
await self._writer.pong(message)
if self._closed:
raise RuntimeError('websocket connection is closing')
self._writer.pong(message)
async def send_str(self, data: str, compress: Optional[bool]=None) -> None:
def send_str(self, data):
if self._writer is None:
raise RuntimeError('Call .prepare() first')
if self._closed:
raise RuntimeError('websocket connection is closing')
if not isinstance(data, str):
raise TypeError('data argument must be str (%r)' % type(data))
await self._writer.send(data, binary=False, compress=compress)
self._writer.send(data, binary=False)
async def send_bytes(self, data: bytes,
compress: Optional[bool]=None) -> None:
def send_bytes(self, data):
if self._writer is None:
raise RuntimeError('Call .prepare() first')
if self._closed:
raise RuntimeError('websocket connection is closing')
if not isinstance(data, (bytes, bytearray, memoryview)):
raise TypeError('data argument must be byte-ish (%r)' %
type(data))
await self._writer.send(data, binary=True, compress=compress)
self._writer.send(data, binary=True)
async def send_json(self, data: Any, compress: Optional[bool]=None, *,
dumps: JSONEncoder=json.dumps) -> None:
await self.send_str(dumps(data), compress=compress)
def send_json(self, data, *, dumps=json.dumps):
self.send_str(dumps(data))
async def write_eof(self) -> None: # type: ignore
@asyncio.coroutine
def write_eof(self):
if self._eof_sent:
return
if self._payload_writer is None:
if self._resp_impl is None:
raise RuntimeError("Response has not been started")
await self.close()
yield from self.close()
self._eof_sent = True
async def close(self, *, code: int=1000, message: bytes=b'') -> bool:
@asyncio.coroutine
def close(self, *, code=1000, message=b''):
if self._writer is None:
raise RuntimeError('Call .prepare() first')
self._cancel_heartbeat()
reader = self._reader
assert reader is not None
# we need to break `receive()` cycle first,
# `close()` may be called from different task
if self._waiting is not None and not self._closed:
reader.feed_data(WS_CLOSING_MESSAGE, 0)
await self._waiting
if not self._closed:
self._closed = True
try:
await self._writer.close(code, message)
writer = self._payload_writer
assert writer is not None
await writer.drain()
self._writer.close(code, message)
except (asyncio.CancelledError, asyncio.TimeoutError):
self._close_code = 1006
raise
@@ -334,22 +198,23 @@ class WebSocketResponse(StreamResponse):
if self._closing:
return True
reader = self._reader
assert reader is not None
try:
with async_timeout.timeout(self._timeout, loop=self._loop):
msg = await reader.read()
except asyncio.CancelledError:
self._close_code = 1006
raise
except Exception as exc:
self._close_code = 1006
self._exception = exc
return True
if msg.type == WSMsgType.CLOSE:
self._close_code = msg.data
return True
begin = self._loop.time()
while self._loop.time() - begin < self._timeout:
try:
with Timeout(timeout=self._timeout,
loop=self._loop):
msg = yield from self._reader.read()
except asyncio.CancelledError:
self._close_code = 1006
raise
except Exception as exc:
self._close_code = 1006
self._exception = exc
return True
if msg.type == WSMsgType.CLOSE:
self._close_code = msg.data
return True
self._close_code = 1006
self._exception = asyncio.TimeoutError()
@@ -357,100 +222,99 @@ class WebSocketResponse(StreamResponse):
else:
return False
async def receive(self, timeout: Optional[float]=None) -> WSMessage:
@asyncio.coroutine
def receive(self):
if self._reader is None:
raise RuntimeError('Call .prepare() first')
if self._waiting:
raise RuntimeError('Concurrent call to receive() is not allowed')
loop = self._loop
assert loop is not None
while True:
if self._waiting is not None:
raise RuntimeError(
'Concurrent call to receive() is not allowed')
if self._closed:
self._conn_lost += 1
if self._conn_lost >= THRESHOLD_CONNLOST_ACCESS:
raise RuntimeError('WebSocket connection is closed.')
return WS_CLOSED_MESSAGE
elif self._closing:
return WS_CLOSING_MESSAGE
self._waiting = True
try:
while True:
if self._closed:
self._conn_lost += 1
if self._conn_lost >= THRESHOLD_CONNLOST_ACCESS:
raise RuntimeError('WebSocket connection is closed.')
return CLOSED_MESSAGE
try:
self._waiting = loop.create_future()
try:
with async_timeout.timeout(
timeout or self._receive_timeout, loop=self._loop):
msg = await self._reader.read()
self._reset_heartbeat()
finally:
waiter = self._waiting
set_result(waiter, True)
self._waiting = None
except (asyncio.CancelledError, asyncio.TimeoutError):
self._close_code = 1006
raise
except EofStream:
self._close_code = 1000
await self.close()
return WSMessage(WSMsgType.CLOSED, None, None)
except WebSocketError as exc:
self._close_code = exc.code
await self.close(code=exc.code)
return WSMessage(WSMsgType.ERROR, exc, None)
except Exception as exc:
self._exception = exc
self._closing = True
self._close_code = 1006
await self.close()
return WSMessage(WSMsgType.ERROR, exc, None)
if msg.type == WSMsgType.CLOSE:
self._closing = True
self._close_code = msg.data
if not self._closed and self._autoclose:
await self.close()
elif msg.type == WSMsgType.CLOSING:
self._closing = True
elif msg.type == WSMsgType.PING and self._autoping:
await self.pong(msg.data)
continue
elif msg.type == WSMsgType.PONG and self._autoping:
continue
return msg
async def receive_str(self, *, timeout: Optional[float]=None) -> str:
msg = await self.receive(timeout)
msg = yield from self._reader.read()
except (asyncio.CancelledError, asyncio.TimeoutError):
raise
except WebSocketError as exc:
self._close_code = exc.code
yield from self.close(code=exc.code)
return WSMessage(WSMsgType.ERROR, exc, None)
except ClientDisconnectedError:
self._closed = True
self._close_code = 1006
return WSMessage(WSMsgType.CLOSE, None, None)
except Exception as exc:
self._exception = exc
self._closing = True
self._close_code = 1006
yield from self.close()
return WSMessage(WSMsgType.ERROR, exc, None)
if msg.type == WSMsgType.CLOSE:
self._closing = True
self._close_code = msg.data
if not self._closed and self._autoclose:
yield from self.close()
return msg
if msg.type == WSMsgType.PING and self._autoping:
self.pong(msg.data)
elif msg.type == WSMsgType.PONG and self._autoping:
continue
else:
return msg
finally:
self._waiting = False
@asyncio.coroutine
def receive_msg(self):
warnings.warn(
'receive_msg() coroutine is deprecated. use receive() instead',
DeprecationWarning)
return (yield from self.receive())
@asyncio.coroutine
def receive_str(self):
msg = yield from self.receive()
if msg.type != WSMsgType.TEXT:
raise TypeError(
"Received message {}:{!r} is not WSMsgType.TEXT".format(
msg.type, msg.data))
"Received message {}:{!r} is not str".format(msg.type,
msg.data))
return msg.data
async def receive_bytes(self, *, timeout: Optional[float]=None) -> bytes:
msg = await self.receive(timeout)
@asyncio.coroutine
def receive_bytes(self):
msg = yield from self.receive()
if msg.type != WSMsgType.BINARY:
raise TypeError(
"Received message {}:{!r} is not bytes".format(msg.type,
msg.data))
return msg.data
async def receive_json(self, *, loads: JSONDecoder=json.loads,
timeout: Optional[float]=None) -> Any:
data = await self.receive_str(timeout=timeout)
@asyncio.coroutine
def receive_json(self, *, loads=json.loads):
data = yield from self.receive_str()
return loads(data)
async def write(self, data: bytes) -> None:
def write(self, data):
raise RuntimeError("Cannot call .write() for websocket")
def __aiter__(self) -> 'WebSocketResponse':
return self
if PY_35:
def __aiter__(self):
return self
if not PY_352: # pragma: no cover
__aiter__ = asyncio.coroutine(__aiter__)
async def __anext__(self) -> WSMessage:
msg = await self.receive()
if msg.type in (WSMsgType.CLOSE,
WSMsgType.CLOSING,
WSMsgType.CLOSED):
raise StopAsyncIteration # NOQA
return msg
@asyncio.coroutine
def __anext__(self):
msg = yield from self.receive()
if msg.type == WSMsgType.CLOSE:
raise StopAsyncIteration # NOQA
return msg
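# --- Illustrative aside (not part of the commit): with the heartbeat support
# introduced above, a handler can rely on automatic ping/pong. A hedged
# sketch of the new-style API (handler name is made up):
#
#     from aiohttp import web
#
#     async def ws_handler(request):
#         ws = web.WebSocketResponse(heartbeat=30.0)  # ping every 30 s;
#         await ws.prepare(request)                   # close 1006 if no pong
#         async for msg in ws:                        # uses __aiter__ above
#             if msg.type == web.WSMsgType.TEXT:
#                 await ws.send_str(msg.data)
#         return ws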

@@ -4,30 +4,15 @@ import asyncio
import os
import re
import signal
import ssl
import sys
from types import FrameType
from typing import Any, Awaitable, Callable, Optional, Union # noqa
import gunicorn.workers.base as base
from gunicorn.config import AccessLogFormat as GunicornAccessLogFormat
from gunicorn.workers import base
from aiohttp import web
from aiohttp.helpers import AccessLogger, ensure_future
from .helpers import set_result
from .web_app import Application
from .web_log import AccessLogger
try:
import ssl
SSLContext = ssl.SSLContext # noqa
except ImportError: # pragma: no cover
ssl = None # type: ignore
SSLContext = object # type: ignore
__all__ = ('GunicornWebWorker',
'GunicornUVLoopWebWorker',
'GunicornTokioWebWorker')
__all__ = ('GunicornWebWorker', 'GunicornUVLoopWebWorker')
class GunicornWebWorker(base.Worker):
@@ -35,14 +20,13 @@ class GunicornWebWorker(base.Worker):
DEFAULT_AIOHTTP_LOG_FORMAT = AccessLogger.LOG_FORMAT
DEFAULT_GUNICORN_LOG_FORMAT = GunicornAccessLogFormat.default
def __init__(self, *args: Any, **kw: Any) -> None: # pragma: no cover
def __init__(self, *args, **kw): # pragma: no cover
super().__init__(*args, **kw)
self._task = None # type: Optional[asyncio.Task[None]]
self.servers = {}
self.exit_code = 0
self._notify_waiter = None # type: Optional[asyncio.Future[bool]]
def init_process(self) -> None:
def init_process(self):
# create new event_loop after fork
asyncio.get_event_loop().close()
@@ -51,56 +35,71 @@ class GunicornWebWorker(base.Worker):
super().init_process()
def run(self) -> None:
self._task = self.loop.create_task(self._run())
def run(self):
self.loop.run_until_complete(self.wsgi.startup())
self._runner = ensure_future(self._run(), loop=self.loop)
try: # ignore all finalization problems
self.loop.run_until_complete(self._task)
except Exception:
self.log.exception("Exception in gunicorn worker")
if sys.version_info >= (3, 6):
self.loop.run_until_complete(self.loop.shutdown_asyncgens())
self.loop.close()
try:
self.loop.run_until_complete(self._runner)
finally:
self.loop.close()
sys.exit(self.exit_code)
async def _run(self) -> None:
if isinstance(self.wsgi, Application):
app = self.wsgi
elif asyncio.iscoroutinefunction(self.wsgi):
app = await self.wsgi()
else:
raise RuntimeError("wsgi app should be either Application or "
"async function returning Application, got {}"
.format(self.wsgi))
access_log = self.log.access_log if self.cfg.accesslog else None
runner = web.AppRunner(app,
logger=self.log,
keepalive_timeout=self.cfg.keepalive,
access_log=access_log,
access_log_format=self._get_valid_log_format(
self.cfg.access_log_format))
await runner.setup()
def make_handler(self, app):
return app.make_handler(
logger=self.log,
slow_request_timeout=self.cfg.timeout,
keepalive_timeout=self.cfg.keepalive,
access_log=self.log.access_log,
access_log_format=self._get_valid_log_format(
self.cfg.access_log_format))
@asyncio.coroutine
def close(self):
if self.servers:
servers = self.servers
self.servers = None
# stop accepting connections
for server, handler in servers.items():
self.log.info("Stopping server: %s, connections: %s",
self.pid, len(handler.connections))
server.close()
yield from server.wait_closed()
# send on_shutdown event
yield from self.wsgi.shutdown()
# stop alive connections
tasks = [
handler.finish_connections(
timeout=self.cfg.graceful_timeout / 100 * 95)
for handler in servers.values()]
yield from asyncio.gather(*tasks, loop=self.loop)
# cleanup application
yield from self.wsgi.cleanup()
@asyncio.coroutine
def _run(self):
ctx = self._create_ssl_context(self.cfg) if self.cfg.is_ssl else None
assert runner is not None
server = runner.server
assert server is not None
for sock in self.sockets:
site = web.SockSite(
runner, sock, ssl_context=ctx,
shutdown_timeout=self.cfg.graceful_timeout / 100 * 95)
await site.start()
handler = self.make_handler(self.wsgi)
srv = yield from self.loop.create_server(handler, sock=sock.sock,
ssl=ctx)
self.servers[srv] = handler
# If our parent changed then we shut down.
pid = os.getpid()
try:
while self.alive: # type: ignore
while self.alive:
self.notify()
cnt = server.requests_count
cnt = sum(handler.requests_count
for handler in self.servers.values())
if self.cfg.max_requests and cnt > self.cfg.max_requests:
self.alive = False
self.log.info("Max requests, shutting down: %s", self)
@@ -109,32 +108,14 @@ class GunicornWebWorker(base.Worker):
self.alive = False
self.log.info("Parent changed, shutting down: %s", self)
else:
await self._wait_next_notify()
yield from asyncio.sleep(1.0, loop=self.loop)
except BaseException:
pass
await runner.cleanup()
def _wait_next_notify(self) -> 'asyncio.Future[bool]':
self._notify_waiter_done()
yield from self.close()
loop = self.loop
assert loop is not None
self._notify_waiter = waiter = loop.create_future()
self.loop.call_later(1.0, self._notify_waiter_done, waiter)
return waiter
def _notify_waiter_done(self, waiter: 'asyncio.Future[bool]'=None) -> None:
if waiter is None:
waiter = self._notify_waiter
if waiter is not None:
set_result(waiter, True)
if waiter is self._notify_waiter:
self._notify_waiter = None
def init_signals(self) -> None:
def init_signals(self):
# Set up signals through the event loop API.
self.loop.add_signal_handler(signal.SIGQUIT, self.handle_quit,
@@ -160,30 +141,19 @@ class GunicornWebWorker(base.Worker):
signal.siginterrupt(signal.SIGTERM, False)
signal.siginterrupt(signal.SIGUSR1, False)
def handle_quit(self, sig: int, frame: FrameType) -> None:
def handle_quit(self, sig, frame):
self.alive = False
# worker_int callback
self.cfg.worker_int(self)
# wakeup closing process
self._notify_waiter_done()
def handle_abort(self, sig: int, frame: FrameType) -> None:
def handle_abort(self, sig, frame):
self.alive = False
self.exit_code = 1
self.cfg.worker_abort(self)
sys.exit(1)
@staticmethod
def _create_ssl_context(cfg: Any) -> 'SSLContext':
def _create_ssl_context(cfg):
""" Creates SSLContext instance for usage in asyncio.create_server.
See ssl.SSLSocket.__init__ for more details.
"""
if ssl is None: # pragma: no cover
raise RuntimeError('SSL is not supported.')
ctx = ssl.SSLContext(cfg.ssl_version)
ctx.load_cert_chain(cfg.certfile, cfg.keyfile)
ctx.verify_mode = cfg.cert_reqs
@@ -193,7 +163,7 @@ class GunicornWebWorker(base.Worker):
ctx.set_ciphers(cfg.ciphers)
return ctx
def _get_valid_log_format(self, source_format: str) -> str:
def _get_valid_log_format(self, source_format):
if source_format == self.DEFAULT_GUNICORN_LOG_FORMAT:
return self.DEFAULT_AIOHTTP_LOG_FORMAT
elif re.search(r'%\([^\)]+\)', source_format):
@@ -201,7 +171,7 @@ class GunicornWebWorker(base.Worker):
"Gunicorn's style options in form of `%(name)s` are not "
"supported for the log formatting. Please use aiohttp's "
"format specification to configure access log formatting: "
"http://docs.aiohttp.org/en/stable/logging.html"
"http://aiohttp.readthedocs.io/en/stable/logging.html"
"#format-specification"
)
else:
@@ -210,7 +180,7 @@ class GunicornWebWorker(base.Worker):
class GunicornUVLoopWebWorker(GunicornWebWorker):
def init_process(self) -> None:
def init_process(self):
import uvloop
# Close any existing event loop before setting a
@@ -223,20 +193,3 @@ class GunicornUVLoopWebWorker(GunicornWebWorker):
asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
super().init_process()
class GunicornTokioWebWorker(GunicornWebWorker):
def init_process(self) -> None: # pragma: no cover
import tokio
# Close any existing event loop before setting a
# new policy.
asyncio.get_event_loop().close()
# Setup tokio policy, so that every
# asyncio.get_event_loop() will create an instance
# of tokio event loop.
asyncio.set_event_loop_policy(tokio.EventLoopPolicy())
super().init_process()
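# --- Illustrative aside (not part of the commit): these workers are selected
# through gunicorn's --worker-class option; per the aiohttp deployment docs,
# something like (module/app names are made up):
#
#     gunicorn my_module:app --bind localhost:8080 \
#         --worker-class aiohttp.GunicornWebWorker
#
# where my_module exposes either a web.Application or an async factory
# returning one, matching the isinstance/iscoroutinefunction check in _run().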

@@ -0,0 +1,235 @@
"""wsgi server.
TODO:
* proxy protocol
* x-forward security
* wsgi file support (os.sendfile)
"""
import asyncio
import inspect
import io
import os
import socket
import sys
from urllib.parse import urlsplit
import aiohttp
from aiohttp import hdrs, server
__all__ = ('WSGIServerHttpProtocol',)
class WSGIServerHttpProtocol(server.ServerHttpProtocol):
"""HTTP Server that implements the Python WSGI protocol.
It sets 'wsgi.async' to True. 'wsgi.input' behaves differently
depending on the 'readpayload' constructor parameter. If readpayload is
True, the wsgi server reads all incoming data into a BytesIO object and
passes it as the 'wsgi.input' environ var. If readpayload is False
(the default), 'wsgi.input' is a StreamReader and the application should
read incoming data with "yield from environ['wsgi.input'].read()".
"""
SCRIPT_NAME = os.environ.get('SCRIPT_NAME', '')
def __init__(self, app, readpayload=False, is_ssl=False, *args, **kw):
super().__init__(*args, **kw)
self.wsgi = app
self.is_ssl = is_ssl
self.readpayload = readpayload
def create_wsgi_response(self, message):
return WsgiResponse(self.writer, message)
def create_wsgi_environ(self, message, payload):
uri_parts = urlsplit(message.path)
environ = {
'wsgi.input': payload,
'wsgi.errors': sys.stderr,
'wsgi.version': (1, 0),
'wsgi.async': True,
'wsgi.multithread': False,
'wsgi.multiprocess': False,
'wsgi.run_once': False,
'wsgi.file_wrapper': FileWrapper,
'SERVER_SOFTWARE': aiohttp.HttpMessage.SERVER_SOFTWARE,
'REQUEST_METHOD': message.method,
'QUERY_STRING': uri_parts.query or '',
'RAW_URI': message.path,
'SERVER_PROTOCOL': 'HTTP/%s.%s' % message.version
}
script_name = self.SCRIPT_NAME
for hdr_name, hdr_value in message.headers.items():
hdr_name = hdr_name.upper()
if hdr_name == 'SCRIPT_NAME':
script_name = hdr_value
elif hdr_name == 'CONTENT-TYPE':
environ['CONTENT_TYPE'] = hdr_value
continue
elif hdr_name == 'CONTENT-LENGTH':
environ['CONTENT_LENGTH'] = hdr_value
continue
key = 'HTTP_%s' % hdr_name.replace('-', '_')
if key in environ:
hdr_value = '%s,%s' % (environ[key], hdr_value)
environ[key] = hdr_value
url_scheme = environ.get('HTTP_X_FORWARDED_PROTO')
if url_scheme is None:
url_scheme = 'https' if self.is_ssl else 'http'
environ['wsgi.url_scheme'] = url_scheme
# Authors should be aware that REMOTE_HOST and REMOTE_ADDR
# may not fully qualify the remote address. Also, the SERVER_PORT
# variable MUST be set to the TCP/IP port number on which this
# request was received from the client.
# http://www.ietf.org/rfc/rfc3875
family = self.transport.get_extra_info('socket').family
if family in (socket.AF_INET, socket.AF_INET6):
peername = self.transport.get_extra_info('peername')
environ['REMOTE_ADDR'] = peername[0]
environ['REMOTE_PORT'] = str(peername[1])
http_host = message.headers.get("HOST", None)
if http_host:
hostport = http_host.split(":")
environ['SERVER_NAME'] = hostport[0]
if len(hostport) > 1:
environ['SERVER_PORT'] = str(hostport[1])
else:
environ['SERVER_PORT'] = '80'
else:
# SERVER_NAME should be set to the value of the Host header, but this
# header is not required. In that case we should set it to the local
# address of the socket
sockname = self.transport.get_extra_info('sockname')
environ['SERVER_NAME'] = sockname[0]
environ['SERVER_PORT'] = str(sockname[1])
else:
# We are behind a reverse proxy, so get all vars from headers
for header in ('REMOTE_ADDR', 'REMOTE_PORT',
'SERVER_NAME', 'SERVER_PORT'):
environ[header] = message.headers.get(header, '')
path_info = uri_parts.path
if script_name:
path_info = path_info.split(script_name, 1)[-1]
environ['PATH_INFO'] = path_info
environ['SCRIPT_NAME'] = script_name
environ['async.reader'] = self.reader
environ['async.writer'] = self.writer
return environ
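# --- Illustrative aside (not part of the commit): a worked example of the
# mapping above. For "GET /app/x?q=1" with SCRIPT_NAME '/app', urlsplit()
# gives path '/app/x' and query 'q=1', so the environ ends up with
# SCRIPT_NAME='/app', PATH_INFO='/x' (path_info.split(script_name, 1)[-1]
# strips the script prefix) and QUERY_STRING='q=1'.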
@asyncio.coroutine
def handle_request(self, message, payload):
"""Handle a single HTTP request"""
now = self._loop.time()
if self.readpayload:
wsgiinput = io.BytesIO()
wsgiinput.write((yield from payload.read()))
wsgiinput.seek(0)
payload = wsgiinput
environ = self.create_wsgi_environ(message, payload)
response = self.create_wsgi_response(message)
riter = self.wsgi(environ, response.start_response)
if isinstance(riter, asyncio.Future) or inspect.isgenerator(riter):
riter = yield from riter
resp = response.response
try:
for item in riter:
if isinstance(item, asyncio.Future):
item = yield from item
yield from resp.write(item)
yield from resp.write_eof()
finally:
if hasattr(riter, 'close'):
riter.close()
if resp.keep_alive():
self.keep_alive(True)
self.log_access(
message, environ, response.response, self._loop.time() - now)
class FileWrapper:
"""Custom file wrapper."""
def __init__(self, fobj, chunk_size=8192):
self.fobj = fobj
self.chunk_size = chunk_size
if hasattr(fobj, 'close'):
self.close = fobj.close
def __iter__(self):
return self
def __next__(self):
data = self.fobj.read(self.chunk_size)
if data:
return data
raise StopIteration
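# --- Illustrative aside (not part of the commit): FileWrapper is what the
# 'wsgi.file_wrapper' environ key above points at. A WSGI app may return
#
#     FileWrapper(open('report.pdf', 'rb'))   # hypothetical file name
#
# and the server then iterates it in 8192-byte chunks; close is aliased to
# fobj.close, so handle_request()'s riter.close() releases the file.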
class WsgiResponse:
"""Implementation of start_response() callable as specified by PEP 3333"""
status = None
HOP_HEADERS = {
hdrs.CONNECTION,
hdrs.KEEP_ALIVE,
hdrs.PROXY_AUTHENTICATE,
hdrs.PROXY_AUTHORIZATION,
hdrs.TE,
hdrs.TRAILER,
hdrs.TRANSFER_ENCODING,
hdrs.UPGRADE,
}
def __init__(self, writer, message):
self.writer = writer
self.message = message
def start_response(self, status, headers, exc_info=None):
if exc_info:
try:
if self.status:
raise exc_info[1]
finally:
exc_info = None
status_code = int(status.split(' ', 1)[0])
self.status = status
resp = self.response = aiohttp.Response(
self.writer, status_code,
self.message.version, self.message.should_close)
resp.HOP_HEADERS = self.HOP_HEADERS
for name, value in headers:
resp.add_header(name, value)
if resp.has_chunked_hdr:
resp.enable_chunked_encoding()
# send headers immediately for websocket connection
if status_code == 101 and resp.upgrade and resp.websocket:
resp.send_headers()
else:
resp._send_headers = True
return self.response.write

@@ -999,7 +999,6 @@ class Bot(BotBase, discord.Client):
"""
pass
class AutoShardedBot(BotBase, discord.AutoShardedClient):
"""This is similar to :class:`.Bot` except that it is inherited from
:class:`discord.AutoShardedClient` instead.

@@ -0,0 +1,30 @@
WebSockets
==========
``websockets`` is a library for developing WebSocket servers_ and clients_ in
Python. It implements `RFC 6455`_ with a focus on correctness and simplicity.
It passes the `Autobahn Testsuite`_.
Built on top of Python's asynchronous I/O support introduced in `PEP 3156`_,
it provides an API based on coroutines, making it easy to write highly
concurrent applications.
Installation is as simple as ``pip install websockets``. It requires Python ≥
3.4 or Python 3.3 with the ``asyncio`` module, which is available with ``pip
install asyncio``.
Documentation is available on `Read the Docs`_.
Bug reports, patches and suggestions welcome! Just open an issue_ or send a
`pull request`_.
.. _servers: https://github.com/aaugustin/websockets/blob/master/example/server.py
.. _clients: https://github.com/aaugustin/websockets/blob/master/example/client.py
.. _RFC 6455: http://tools.ietf.org/html/rfc6455
.. _Autobahn Testsuite: https://github.com/aaugustin/websockets/blob/master/compliance/README.rst
.. _PEP 3156: http://www.python.org/dev/peps/pep-3156/
.. _Read the Docs: https://websockets.readthedocs.io/
.. _issue: https://github.com/aaugustin/websockets/issues/new
.. _pull request: https://github.com/aaugustin/websockets/compare/

@@ -0,0 +1,52 @@
Metadata-Version: 2.0
Name: websockets
Version: 3.4
Summary: An implementation of the WebSocket Protocol (RFC 6455)
Home-page: https://github.com/aaugustin/websockets
Author: Aymeric Augustin
Author-email: aymeric.augustin@m4x.org
License: BSD
Platform: UNKNOWN
Classifier: Development Status :: 5 - Production/Stable
Classifier: Environment :: Web Environment
Classifier: Intended Audience :: Developers
Classifier: License :: OSI Approved :: BSD License
Classifier: Operating System :: OS Independent
Classifier: Programming Language :: Python
Classifier: Programming Language :: Python :: 3
Classifier: Programming Language :: Python :: 3.3
Classifier: Programming Language :: Python :: 3.4
Classifier: Programming Language :: Python :: 3.5
Classifier: Programming Language :: Python :: 3.6
Requires-Dist: asyncio; python_version=="3.3"

@@ -0,0 +1,56 @@
websockets-3.4.dist-info/DESCRIPTION.rst,sha256=Xfv_W4k7cI7wsQ7_GkxVUWox9Fj9Kbicr9vGzVaN-Rk,1264
websockets-3.4.dist-info/INSTALLER,sha256=zuuue4knoyJ-UwPPXg8fezS7VCrXJQrAP7zeNuwvFQg,4
websockets-3.4.dist-info/METADATA,sha256=GrwKnapJ4AN94EN2RS1qyayejMXWV3K2wuQHzdueBAQ,2105
websockets-3.4.dist-info/RECORD,,
websockets-3.4.dist-info/WHEEL,sha256=4fkP9V5fUlnPlEu2h0nt7u0cPpJipYsJO32nXNevnFk,106
websockets-3.4.dist-info/metadata.json,sha256=6s9m0aXolRcZ2UodIViSHcTIqQuL6gJPJir-ZZl4stA,1000
websockets-3.4.dist-info/top_level.txt,sha256=EcCngZER7Li9SDhzH7kpxRCzPMkWrtR98pBqycxQDZc,27
websockets/__init__.py,sha256=dtzzVSk5GjukX2OhwZhjfYhQBWTQCzYYoruzoivf8f4,383
websockets/__pycache__/__init__.cpython-36.pyc,,
websockets/__pycache__/client.cpython-36.pyc,,
websockets/__pycache__/compatibility.cpython-36.pyc,,
websockets/__pycache__/exceptions.cpython-36.pyc,,
websockets/__pycache__/framing.cpython-36.pyc,,
websockets/__pycache__/handshake.cpython-36.pyc,,
websockets/__pycache__/http.cpython-36.pyc,,
websockets/__pycache__/protocol.cpython-36.pyc,,
websockets/__pycache__/server.cpython-36.pyc,,
websockets/__pycache__/test_client_server.cpython-36.pyc,,
websockets/__pycache__/test_framing.cpython-36.pyc,,
websockets/__pycache__/test_handshake.cpython-36.pyc,,
websockets/__pycache__/test_http.cpython-36.pyc,,
websockets/__pycache__/test_protocol.cpython-36.pyc,,
websockets/__pycache__/test_speedups.cpython-36.pyc,,
websockets/__pycache__/test_uri.cpython-36.pyc,,
websockets/__pycache__/test_utils.cpython-36.pyc,,
websockets/__pycache__/uri.cpython-36.pyc,,
websockets/__pycache__/utils.cpython-36.pyc,,
websockets/__pycache__/version.cpython-36.pyc,,
websockets/client.py,sha256=wr9k4IjcLoN8xLSCIJF06A4j43Ji-u0xvQdd6lDGuAY,8330
websockets/compatibility.py,sha256=DfAlt7rra9enxR0TO0qrAnFLumTGmqhMyhANw3cv9IA,1567
websockets/exceptions.py,sha256=zOGatgy9IM75Ofzsh2qKI2NTfnR9yD5ap3sy4KN_tKk,2309
websockets/framing.py,sha256=lBcIIC82qgNgtcsBSBI4d8C2ydjw_ErI_DgvhieSSUM,6083
websockets/handshake.py,sha256=tqZJV9kTRjm1LOX5ZjFVOFEjE1zs1W2YaIgKu6vjDJ4,4310
websockets/http.py,sha256=G7LHgfMqjs5dYaLwxxzpa2rR14wWcz_qu4qW8SLsA9Q,6421
websockets/protocol.py,sha256=jUku43eyUR8qYl6NCpGprK3SRGtoc982jP4jv1HQIm0,26694
websockets/py35/__init__.py,sha256=yYRQ_76YWF-La7UK4uZwOcPEIp4CE3RwXZ99BUM7qX4,152
websockets/py35/__pycache__/__init__.cpython-36.pyc,,
websockets/py35/__pycache__/client.cpython-36.pyc,,
websockets/py35/__pycache__/client_server.cpython-36.pyc,,
websockets/py35/__pycache__/server.cpython-36.pyc,,
websockets/py35/client.py,sha256=nveWJBe17u584UR6NjYl-bU52zmsZe-p2HMIcw9ooCk,568
websockets/py35/client_server.py,sha256=Y0SWqtC4qhy02lr3Tn-HftM1pan9vMKHEgMu-kOcY5A,1341
websockets/py35/server.py,sha256=vmxXu77VNvOPUC3opeMJRcUG6Vs8bEzLhRSGUiYclYQ,589
websockets/server.py,sha256=jpT0HGvZag17QP373wTGpR6-ZcHmZGuaLo2ih_yB7v0,20691
websockets/speedups.cp36-win_amd64.pyd,sha256=FwNTiJryIYA5m_Nmkmu1koT-BM8jFFayeK1MO-_f9_k,11264
websockets/test_client_server.py,sha256=3g87ybCp6rkXWt2Dvi00_xKOlT5dOy_Q9O6PyGAmGDc,23413
websockets/test_framing.py,sha256=zWqDNwDAF6qHwV9j9rFiEePxu7s5bVFqHaDZFLi3MOY,5371
websockets/test_handshake.py,sha256=TKyzYrqRURklO3dEH1oq8fFn64repe2U415ARLJ5DoE,4223
websockets/test_http.py,sha256=DnmICytpKL1iswq_TA9SR1VO3Md3T5asUEP9i__-8OM,4663
websockets/test_protocol.py,sha256=xh-AMIMJrgRrquzHS5mry9ZGoa_yQcXXUnx2inhnwM0,30556
websockets/test_speedups.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
websockets/test_uri.py,sha256=JDs1Orabj9WWrykp6nUKhyPIe7w-uRmAoGgQKPntDBs,932
websockets/test_utils.py,sha256=eNGhH9RsTOxcEhUbJRs04du_r3EIGmXYPuTSOwE5uDY,1441
websockets/uri.py,sha256=80xWOSHYFcnFjGJ5tyxkUVFkha2wp2FFepSxbXYts9w,1517
websockets/utils.py,sha256=Qt5rpZiL_6u4KgS5WFSXTrFhmDbZzydA6LGBM6DacMY,277
websockets/version.py,sha256=mCvUBEXtlj9qxGDKLmrm21r_qKAXR2K2pfPYy6W7sCg,16

@@ -1,5 +1,5 @@
Wheel-Version: 1.0
Generator: bdist_wheel (0.33.6)
Generator: bdist_wheel (0.29.0)
Root-Is-Purelib: false
Tag: cp36-cp36m-win_amd64

@@ -0,0 +1 @@
{"classifiers": ["Development Status :: 5 - Production/Stable", "Environment :: Web Environment", "Intended Audience :: Developers", "License :: OSI Approved :: BSD License", "Operating System :: OS Independent", "Programming Language :: Python", "Programming Language :: Python :: 3", "Programming Language :: Python :: 3.3", "Programming Language :: Python :: 3.4", "Programming Language :: Python :: 3.5", "Programming Language :: Python :: 3.6"], "extensions": {"python.details": {"contacts": [{"email": "aymeric.augustin@m4x.org", "name": "Aymeric Augustin", "role": "author"}], "document_names": {"description": "DESCRIPTION.rst"}, "project_urls": {"Home": "https://github.com/aaugustin/websockets"}}}, "extras": [], "generator": "bdist_wheel (0.29.0)", "license": "BSD", "metadata_version": "2.0", "name": "websockets", "run_requires": [{"environment": "python_version==\"3.3\"", "requires": ["asyncio"]}], "summary": "An implementation of the WebSocket Protocol (RFC 6455)", "version": "3.4"}

@@ -1,25 +0,0 @@
Copyright (c) 2013-2019 Aymeric Augustin and contributors.
All rights reserved.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are met:
* Redistributions of source code must retain the above copyright notice,
this list of conditions and the following disclaimer.
* Redistributions in binary form must reproduce the above copyright notice,
this list of conditions and the following disclaimer in the documentation
and/or other materials provided with the distribution.
* Neither the name of websockets nor the names of its contributors may
be used to endorse or promote products derived from this software without
specific prior written permission.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

@@ -1,167 +0,0 @@
Metadata-Version: 2.1
Name: websockets
Version: 8.1
Summary: An implementation of the WebSocket Protocol (RFC 6455 & 7692)
Home-page: https://github.com/aaugustin/websockets
Author: Aymeric Augustin
Author-email: aymeric.augustin@m4x.org
License: BSD
Platform: UNKNOWN
Classifier: Development Status :: 5 - Production/Stable
Classifier: Environment :: Web Environment
Classifier: Intended Audience :: Developers
Classifier: License :: OSI Approved :: BSD License
Classifier: Operating System :: OS Independent
Classifier: Programming Language :: Python
Classifier: Programming Language :: Python :: 3
Classifier: Programming Language :: Python :: 3.6
Classifier: Programming Language :: Python :: 3.7
Classifier: Programming Language :: Python :: 3.8
Requires-Python: >=3.6.1
.. image:: logo/horizontal.svg
:width: 480px
:alt: websockets
|rtd| |pypi-v| |pypi-pyversions| |pypi-l| |pypi-wheel| |circleci| |codecov|
.. |rtd| image:: https://readthedocs.org/projects/websockets/badge/?version=latest
:target: https://websockets.readthedocs.io/
.. |pypi-v| image:: https://img.shields.io/pypi/v/websockets.svg
:target: https://pypi.python.org/pypi/websockets
.. |pypi-pyversions| image:: https://img.shields.io/pypi/pyversions/websockets.svg
:target: https://pypi.python.org/pypi/websockets
.. |pypi-l| image:: https://img.shields.io/pypi/l/websockets.svg
:target: https://pypi.python.org/pypi/websockets
.. |pypi-wheel| image:: https://img.shields.io/pypi/wheel/websockets.svg
:target: https://pypi.python.org/pypi/websockets
.. |circleci| image:: https://img.shields.io/circleci/project/github/aaugustin/websockets.svg
:target: https://circleci.com/gh/aaugustin/websockets
.. |codecov| image:: https://codecov.io/gh/aaugustin/websockets/branch/master/graph/badge.svg
:target: https://codecov.io/gh/aaugustin/websockets
What is ``websockets``?
-----------------------
``websockets`` is a library for building WebSocket servers_ and clients_ in
Python with a focus on correctness and simplicity.
.. _servers: https://github.com/aaugustin/websockets/blob/master/example/server.py
.. _clients: https://github.com/aaugustin/websockets/blob/master/example/client.py
Built on top of ``asyncio``, Python's standard asynchronous I/O framework, it
provides an elegant coroutine-based API.
`Documentation is available on Read the Docs. <https://websockets.readthedocs.io/>`_
Here's how a client sends and receives messages:
.. copy-pasted because GitHub doesn't support the include directive
.. code:: python
#!/usr/bin/env python
import asyncio
import websockets
async def hello(uri):
async with websockets.connect(uri) as websocket:
await websocket.send("Hello world!")
await websocket.recv()
asyncio.get_event_loop().run_until_complete(
hello('ws://localhost:8765'))
And here's an echo server:
.. code:: python
#!/usr/bin/env python
import asyncio
import websockets
async def echo(websocket, path):
async for message in websocket:
await websocket.send(message)
asyncio.get_event_loop().run_until_complete(
websockets.serve(echo, 'localhost', 8765))
asyncio.get_event_loop().run_forever()
Does that look good?
`Get started with the tutorial! <https://websockets.readthedocs.io/en/stable/intro.html>`_
Why should I use ``websockets``?
--------------------------------
The development of ``websockets`` is shaped by four principles:
1. **Simplicity**: all you need to understand is ``msg = await ws.recv()`` and
``await ws.send(msg)``; ``websockets`` takes care of managing connections
so you can focus on your application.
2. **Robustness**: ``websockets`` is built for production; for example it was
the only library to `handle backpressure correctly`_ before the issue
became widely known in the Python community.
3. **Quality**: ``websockets`` is heavily tested. Continuous integration fails
if branch coverage drops below 100%. It also passes the industry-standard
`Autobahn Testsuite`_.
4. **Performance**: memory use is configurable. An extension written in C
accelerates expensive operations. It's pre-compiled for Linux, macOS and
Windows and packaged in the wheel format for each system and Python version.
Documentation is a first-class concern in the project. Head over to `Read the
Docs`_ and see for yourself.
.. _Read the Docs: https://websockets.readthedocs.io/
.. _handle backpressure correctly: https://vorpus.org/blog/some-thoughts-on-asynchronous-api-design-in-a-post-asyncawait-world/#websocket-servers
.. _Autobahn Testsuite: https://github.com/aaugustin/websockets/blob/master/compliance/README.rst
Why shouldn't I use ``websockets``?
-----------------------------------
* If you prefer callbacks over coroutines: ``websockets`` was created to
provide the best coroutine-based API to manage WebSocket connections in
Python. Pick another library for a callback-based API.
* If you're looking for a mixed HTTP / WebSocket library: ``websockets`` aims
at being an excellent implementation of :rfc:`6455`: The WebSocket Protocol
and :rfc:`7692`: Compression Extensions for WebSocket. Its support for HTTP
is minimal — just enough for an HTTP health check.
* If you want to use Python 2: ``websockets`` builds upon ``asyncio``, which
only works on Python 3. ``websockets`` requires Python ≥ 3.6.1.
What else?
----------
Bug reports, patches and suggestions are welcome!
To report a security vulnerability, please use the `Tidelift security
contact`_. Tidelift will coordinate the fix and disclosure.
.. _Tidelift security contact: https://tidelift.com/security
For anything else, please open an issue_ or send a `pull request`_.
.. _issue: https://github.com/aaugustin/websockets/issues/new
.. _pull request: https://github.com/aaugustin/websockets/compare/
Participants must uphold the `Contributor Covenant code of conduct`_.
.. _Contributor Covenant code of conduct: https://github.com/aaugustin/websockets/blob/master/CODE_OF_CONDUCT.md
``websockets`` is released under the `BSD license`_.
.. _BSD license: https://github.com/aaugustin/websockets/blob/master/LICENSE

@ -1,44 +0,0 @@
websockets-8.1.dist-info/INSTALLER,sha256=zuuue4knoyJ-UwPPXg8fezS7VCrXJQrAP7zeNuwvFQg,4
websockets-8.1.dist-info/LICENSE,sha256=ioiWDA1qqLOPqm7mFFl9vxjg6iK6IIhO890x00sqbQk,1536
websockets-8.1.dist-info/METADATA,sha256=Ybp0HGQ7VBJZMyP-ZNors83guDrgCvSa8mdIypZsX3U,6200
websockets-8.1.dist-info/RECORD,,
websockets-8.1.dist-info/WHEEL,sha256=uQaeujkjkt7SlmOZGXO6onhwBPrzw2WTI2otbCZzdNI,106
websockets-8.1.dist-info/top_level.txt,sha256=KtfDkU36u2JojZLmzHTIQUQ39q6RdOuHhI-rvY5J-FM,33
websockets/__init__.py,sha256=358xBAefIahVJA8sDJbA-zKGd-tVUqKoUJdIDjc8cjQ,1314
websockets/__main__.py,sha256=U3euVZyLJmWYySdr4gP0TmvqiPbfGgKHf1N25E2y8z4,6420
websockets/__pycache__/__init__.cpython-36.pyc,,
websockets/__pycache__/__main__.cpython-36.pyc,,
websockets/__pycache__/auth.cpython-36.pyc,,
websockets/__pycache__/client.cpython-36.pyc,,
websockets/__pycache__/exceptions.cpython-36.pyc,,
websockets/__pycache__/framing.cpython-36.pyc,,
websockets/__pycache__/handshake.cpython-36.pyc,,
websockets/__pycache__/headers.cpython-36.pyc,,
websockets/__pycache__/http.cpython-36.pyc,,
websockets/__pycache__/protocol.cpython-36.pyc,,
websockets/__pycache__/server.cpython-36.pyc,,
websockets/__pycache__/typing.cpython-36.pyc,,
websockets/__pycache__/uri.cpython-36.pyc,,
websockets/__pycache__/utils.cpython-36.pyc,,
websockets/__pycache__/version.cpython-36.pyc,,
websockets/auth.py,sha256=R11zxLlNsK-RHlsdryL3pXgv5fWwJACJnqBpwHESGF8,5414
websockets/client.py,sha256=3xRLPV7pDusiPHKV6eF4YUFazKyGX7jFwhMk-Vy4Tw8,21215
websockets/exceptions.py,sha256=CLfGkck8Qt-_ICMfj0E8NaGgEotgp_6XGoDaxn_zthY,8824
websockets/extensions/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
websockets/extensions/__pycache__/__init__.cpython-36.pyc,,
websockets/extensions/__pycache__/base.cpython-36.pyc,,
websockets/extensions/__pycache__/permessage_deflate.cpython-36.pyc,,
websockets/extensions/base.py,sha256=pefUdApzb7Z0zV_AGlMVlpE3Bj3QcTFplBfApaoptgs,2784
websockets/extensions/permessage_deflate.py,sha256=FT_kqFSHBp0PnI988kS99U1CeNaYI41DQMxOjo53Mxk,21730
websockets/framing.py,sha256=DgeRjpeom5ehX2gb838DUy1g2dhPBVMWncmZQyojhnA,10244
websockets/handshake.py,sha256=GoSlNCymiSZdkm4QWzMjPx-20QF1-w7KhfQ7ODxPEw8,6152
websockets/headers.py,sha256=guRd8T_ppb7OEPXCDH8D4p3hsZcksNPF3YUNp9afXQM,15069
websockets/http.py,sha256=-ay6psjIla3YU-a_0UW3GcXXTBeDqiyanhEg-l_nW88,11826
websockets/protocol.py,sha256=I8p-T-aDsz5m9OZ181I0r8R0jiz1EBcF2EhWFHHYTEw,55030
websockets/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
websockets/server.py,sha256=WZ86dfWq3uLDF2f8o0wG1glsJyYGaMKwwdzLA9CZL3s,38186
websockets/speedups.cp36-win_amd64.pyd,sha256=-xmeYRCiN2buHVC9KafFBWTmHkRz6rurGhSibHUlrXk,12288
websockets/typing.py,sha256=R-OvOXKtMXeu49At4J5hAg9s3Da3iT77cM9luddy33k,1305
websockets/uri.py,sha256=q1VLxJE3q9xEWJI1FfVEqeHtSE1-YOob2sAmvSVXHyU,2244
websockets/utils.py,sha256=wFZSnRYfkz94nN1L9vITaZShs56_I5cOEXeE6Iez7oY,376
websockets/version.py,sha256=6nmfcHh_ydnzhw9ujCQU5lRV4fyb84UgtEV7Ud6QGL8,16

@ -1,55 +1,17 @@
# This relies on each of the submodules having an __all__ variable.
from .auth import * # noqa
from .client import * # noqa
from .exceptions import * # noqa
from .protocol import * # noqa
from .server import * # noqa
from .typing import * # noqa
from .uri import * # noqa
from .version import version as __version__ # noqa
from .client import *
from .exceptions import *
from .protocol import *
from .server import *
from .uri import *
from .version import version as __version__ # noqa
__all__ = [
"AbortHandshake",
"basic_auth_protocol_factory",
"BasicAuthWebSocketServerProtocol",
"connect",
"ConnectionClosed",
"ConnectionClosedError",
"ConnectionClosedOK",
"Data",
"DuplicateParameter",
"ExtensionHeader",
"ExtensionParameter",
"InvalidHandshake",
"InvalidHeader",
"InvalidHeaderFormat",
"InvalidHeaderValue",
"InvalidMessage",
"InvalidOrigin",
"InvalidParameterName",
"InvalidParameterValue",
"InvalidState",
"InvalidStatusCode",
"InvalidUpgrade",
"InvalidURI",
"NegotiationError",
"Origin",
"parse_uri",
"PayloadTooBig",
"ProtocolError",
"RedirectHandshake",
"SecurityError",
"serve",
"Subprotocol",
"unix_connect",
"unix_serve",
"WebSocketClientProtocol",
"WebSocketCommonProtocol",
"WebSocketException",
"WebSocketProtocolError",
"WebSocketServer",
"WebSocketServerProtocol",
"WebSocketURI",
]
__all__ = (
client.__all__ +
exceptions.__all__ +
protocol.__all__ +
server.__all__ +
uri.__all__
)
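# Editor's note: with this aggregation, ``from websockets import *`` re-exports
# the public names declared by each submodule's ``__all__`` (for example
# ``connect``, ``serve``, and ``ConnectionClosed``).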

@ -1,160 +0,0 @@
"""
:mod:`websockets.auth` provides HTTP Basic Authentication according to
:rfc:`7235` and :rfc:`7617`.
"""
import functools
import http
from typing import Any, Awaitable, Callable, Iterable, Optional, Tuple, Type, Union
from .exceptions import InvalidHeader
from .headers import build_www_authenticate_basic, parse_authorization_basic
from .http import Headers
from .server import HTTPResponse, WebSocketServerProtocol
__all__ = ["BasicAuthWebSocketServerProtocol", "basic_auth_protocol_factory"]
Credentials = Tuple[str, str]
def is_credentials(value: Any) -> bool:
try:
username, password = value
except (TypeError, ValueError):
return False
else:
return isinstance(username, str) and isinstance(password, str)
class BasicAuthWebSocketServerProtocol(WebSocketServerProtocol):
"""
WebSocket server protocol that enforces HTTP Basic Auth.
"""
def __init__(
self,
*args: Any,
realm: str,
check_credentials: Callable[[str, str], Awaitable[bool]],
**kwargs: Any,
) -> None:
self.realm = realm
self.check_credentials = check_credentials
super().__init__(*args, **kwargs)
async def process_request(
self, path: str, request_headers: Headers
) -> Optional[HTTPResponse]:
"""
Check HTTP Basic Auth and return an HTTP 401 or 403 response if needed.
If authentication succeeds, the username of the authenticated user is
stored in the ``username`` attribute.
"""
try:
authorization = request_headers["Authorization"]
except KeyError:
return (
http.HTTPStatus.UNAUTHORIZED,
[("WWW-Authenticate", build_www_authenticate_basic(self.realm))],
b"Missing credentials\n",
)
try:
username, password = parse_authorization_basic(authorization)
except InvalidHeader:
return (
http.HTTPStatus.UNAUTHORIZED,
[("WWW-Authenticate", build_www_authenticate_basic(self.realm))],
b"Unsupported credentials\n",
)
if not await self.check_credentials(username, password):
return (
http.HTTPStatus.UNAUTHORIZED,
[("WWW-Authenticate", build_www_authenticate_basic(self.realm))],
b"Invalid credentials\n",
)
self.username = username
return await super().process_request(path, request_headers)
def basic_auth_protocol_factory(
realm: str,
credentials: Optional[Union[Credentials, Iterable[Credentials]]] = None,
check_credentials: Optional[Callable[[str, str], Awaitable[bool]]] = None,
create_protocol: Type[
BasicAuthWebSocketServerProtocol
] = BasicAuthWebSocketServerProtocol,
) -> Callable[[Any], BasicAuthWebSocketServerProtocol]:
"""
Protocol factory that enforces HTTP Basic Auth.
``basic_auth_protocol_factory`` is designed to integrate with
:func:`~websockets.server.serve` like this::
websockets.serve(
...,
create_protocol=websockets.basic_auth_protocol_factory(
realm="my dev server",
credentials=("hello", "iloveyou"),
)
)
``realm`` indicates the scope of protection. It should contain only ASCII
characters because the encoding of non-ASCII characters is undefined.
Refer to section 2.2 of :rfc:`7235` for details.
``credentials`` defines hard coded authorized credentials. It can be a
``(username, password)`` pair or a list of such pairs.
``check_credentials`` defines a coroutine that checks whether credentials
are authorized. This coroutine receives ``username`` and ``password``
arguments and returns a :class:`bool`.
One of ``credentials`` or ``check_credentials`` must be provided but not
both.
By default, ``basic_auth_protocol_factory`` creates a factory for building
:class:`BasicAuthWebSocketServerProtocol` instances. You can override this
with the ``create_protocol`` parameter.
:param realm: scope of protection
:param credentials: hard coded credentials
:param check_credentials: coroutine that verifies credentials
:raises TypeError: if the credentials argument has the wrong type
"""
if (credentials is None) == (check_credentials is None):
raise TypeError("provide either credentials or check_credentials")
if credentials is not None:
if is_credentials(credentials):
async def check_credentials(username: str, password: str) -> bool:
return (username, password) == credentials
elif isinstance(credentials, Iterable):
credentials_list = list(credentials)
if all(is_credentials(item) for item in credentials_list):
credentials_dict = dict(credentials_list)
async def check_credentials(username: str, password: str) -> bool:
return credentials_dict.get(username) == password
else:
raise TypeError(f"invalid credentials argument: {credentials}")
else:
raise TypeError(f"invalid credentials argument: {credentials}")
return functools.partial(
create_protocol, realm=realm, check_credentials=check_credentials
)
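# A minimal usage sketch (editor's addition, not part of this module): wiring
# the factory into ``websockets.serve``. The handler, realm, and credentials
# below are illustrative placeholders.
if __name__ == "__main__":
    import asyncio

    import websockets

    async def hello(websocket, path):
        # ``websocket.username`` was set by BasicAuthWebSocketServerProtocol
        # in process_request() after the credentials were verified.
        await websocket.send(f"Hello, {websocket.username}!")

    start_server = websockets.serve(
        hello,
        "localhost",
        8765,
        create_protocol=websockets.basic_auth_protocol_factory(
            realm="example realm",
            credentials=("user", "secret"),
        ),
    )
    asyncio.get_event_loop().run_until_complete(start_server)
    asyncio.get_event_loop().run_forever()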

@ -1,584 +1,231 @@
"""
:mod:`websockets.client` defines the WebSocket client APIs.
The :mod:`websockets.client` module defines a simple WebSocket client API.
"""
import asyncio
import collections.abc
import functools
import logging
import warnings
from types import TracebackType
from typing import Any, Generator, List, Optional, Sequence, Tuple, Type, cast
from .exceptions import (
InvalidHandshake,
InvalidHeader,
InvalidMessage,
InvalidStatusCode,
NegotiationError,
RedirectHandshake,
SecurityError,
)
from .extensions.base import ClientExtensionFactory, Extension
from .extensions.permessage_deflate import ClientPerMessageDeflateFactory
from .handshake import build_request, check_response
from .headers import (
build_authorization_basic,
build_extension,
build_subprotocol,
parse_extension,
parse_subprotocol,
)
from .http import USER_AGENT, Headers, HeadersLike, read_response
from .protocol import WebSocketCommonProtocol
from .typing import ExtensionHeader, Origin, Subprotocol
from .uri import WebSocketURI, parse_uri
from .exceptions import InvalidHandshake, InvalidMessage, InvalidStatusCode
from .handshake import build_request, check_response
from .http import USER_AGENT, build_headers, read_response
from .protocol import CONNECTING, OPEN, WebSocketCommonProtocol
from .uri import parse_uri
__all__ = ["connect", "unix_connect", "WebSocketClientProtocol"]
logger = logging.getLogger(__name__)
__all__ = ['connect', 'WebSocketClientProtocol']
class WebSocketClientProtocol(WebSocketCommonProtocol):
"""
:class:`~asyncio.Protocol` subclass implementing a WebSocket client.
Complete WebSocket client implementation as an :class:`asyncio.Protocol`.
This class inherits most of its methods from
:class:`~websockets.protocol.WebSocketCommonProtocol`.
"""
is_client = True
side = "client"
def __init__(
self,
*,
origin: Optional[Origin] = None,
extensions: Optional[Sequence[ClientExtensionFactory]] = None,
subprotocols: Optional[Sequence[Subprotocol]] = None,
extra_headers: Optional[HeadersLike] = None,
**kwargs: Any,
) -> None:
self.origin = origin
self.available_extensions = extensions
self.available_subprotocols = subprotocols
self.extra_headers = extra_headers
super().__init__(**kwargs)
def write_http_request(self, path: str, headers: Headers) -> None:
state = CONNECTING
@asyncio.coroutine
def write_http_request(self, path, headers):
"""
Write request line and headers to the HTTP request.
"""
self.path = path
self.request_headers = headers
logger.debug("%s > GET %s HTTP/1.1", self.side, path)
logger.debug("%s > %r", self.side, headers)
self.request_headers = build_headers(headers)
self.raw_request_headers = headers
# Since the path and headers only contain ASCII characters,
# we can keep this simple.
request = f"GET {path} HTTP/1.1\r\n"
request += str(headers)
request = ['GET {path} HTTP/1.1'.format(path=path)]
request.extend('{}: {}'.format(k, v) for k, v in headers)
request.append('\r\n')
request = '\r\n'.join(request).encode()
self.transport.write(request.encode())
self.writer.write(request)
async def read_http_response(self) -> Tuple[int, Headers]:
@asyncio.coroutine
def read_http_response(self):
"""
Read status line and headers from the HTTP response.
If the response contains a body, it may be read from ``self.reader``
after this coroutine returns.
Raise :exc:`~websockets.exceptions.InvalidMessage` if the HTTP message
is malformed or isn't an HTTP/1.1 GET response.
:raises ~websockets.exceptions.InvalidMessage: if the HTTP message is
malformed or isn't an HTTP/1.1 GET response
Don't attempt to read the response body because WebSocket handshake
responses don't have one. If the response contains a body, it may be
read from ``self.reader`` after this coroutine returns.
"""
try:
status_code, reason, headers = await read_response(self.reader)
except Exception as exc:
raise InvalidMessage("did not receive a valid HTTP response") from exc
status_code, headers = yield from read_response(self.reader)
except ValueError as exc:
raise InvalidMessage("Malformed HTTP message") from exc
logger.debug("%s < HTTP/1.1 %d %s", self.side, status_code, reason)
logger.debug("%s < %r", self.side, headers)
self.response_headers = headers
self.response_headers = build_headers(headers)
self.raw_response_headers = headers
return status_code, self.response_headers
@staticmethod
def process_extensions(
headers: Headers,
available_extensions: Optional[Sequence[ClientExtensionFactory]],
) -> List[Extension]:
def process_subprotocol(self, get_header, subprotocols=None):
"""
Handle the Sec-WebSocket-Extensions HTTP response header.
Check that each extension is supported, as well as its parameters.
Return the list of accepted extensions.
Raise :exc:`~websockets.exceptions.InvalidHandshake` to abort the
connection.
:rfc:`6455` leaves the rules up to the specification of each
extension.
To provide this level of flexibility, for each extension accepted by
the server, we check for a match with each extension available in the
client configuration. If no match is found, an exception is raised.
If several variants of the same extension are accepted by the server,
it may be configured several times, which won't make sense in general.
Extensions must implement their own requirements. For this purpose,
the list of previously accepted extensions is provided.
Other requirements, for example related to mandatory extensions or the
order of extensions, may be implemented by overriding this method.
Handle the Sec-WebSocket-Protocol HTTP header.
"""
accepted_extensions: List[Extension] = []
header_values = headers.get_all("Sec-WebSocket-Extensions")
if header_values:
if available_extensions is None:
raise InvalidHandshake("no extensions supported")
parsed_header_values: List[ExtensionHeader] = sum(
[parse_extension(header_value) for header_value in header_values], []
)
for name, response_params in parsed_header_values:
for extension_factory in available_extensions:
# Skip non-matching extensions based on their name.
if extension_factory.name != name:
continue
# Skip non-matching extensions based on their params.
try:
extension = extension_factory.process_response_params(
response_params, accepted_extensions
)
except NegotiationError:
continue
# Add matching extension to the final list.
accepted_extensions.append(extension)
# Break out of the loop once we have a match.
break
# If we didn't break from the loop, no extension in our list
# matched what the server sent. Fail the connection.
else:
raise NegotiationError(
f"Unsupported extension: "
f"name = {name}, params = {response_params}"
)
return accepted_extensions
@staticmethod
def process_subprotocol(
headers: Headers, available_subprotocols: Optional[Sequence[Subprotocol]]
) -> Optional[Subprotocol]:
subprotocol = get_header('Sec-WebSocket-Protocol')
if subprotocol:
if subprotocols is None or subprotocol not in subprotocols:
raise InvalidHandshake(
"Unknown subprotocol: {}".format(subprotocol))
return subprotocol
@asyncio.coroutine
def handshake(self, wsuri,
origin=None, subprotocols=None, extra_headers=None):
"""
Handle the Sec-WebSocket-Protocol HTTP response header.
Check that it contains exactly one supported subprotocol.
Return the selected subprotocol.
"""
subprotocol: Optional[Subprotocol] = None
header_values = headers.get_all("Sec-WebSocket-Protocol")
if header_values:
if available_subprotocols is None:
raise InvalidHandshake("no subprotocols supported")
parsed_header_values: Sequence[Subprotocol] = sum(
[parse_subprotocol(header_value) for header_value in header_values], []
)
if len(parsed_header_values) > 1:
subprotocols = ", ".join(parsed_header_values)
raise InvalidHandshake(f"multiple subprotocols: {subprotocols}")
subprotocol = parsed_header_values[0]
if subprotocol not in available_subprotocols:
raise NegotiationError(f"unsupported subprotocol: {subprotocol}")
Perform the client side of the opening handshake.
return subprotocol
If provided, ``origin`` sets the Origin HTTP header.
async def handshake(
self,
wsuri: WebSocketURI,
origin: Optional[Origin] = None,
available_extensions: Optional[Sequence[ClientExtensionFactory]] = None,
available_subprotocols: Optional[Sequence[Subprotocol]] = None,
extra_headers: Optional[HeadersLike] = None,
) -> None:
"""
Perform the client side of the opening handshake.
If provided, ``subprotocols`` is a list of supported subprotocols in
order of decreasing preference.
:param origin: sets the Origin HTTP header
:param available_extensions: list of supported extensions in the order
in which they should be used
:param available_subprotocols: list of supported subprotocols in order
of decreasing preference
:param extra_headers: sets additional HTTP request headers; it must be
a :class:`~websockets.http.Headers` instance, a
:class:`~collections.abc.Mapping`, or an iterable of ``(name,
value)`` pairs
:raises ~websockets.exceptions.InvalidHandshake: if the handshake
fails
If provided, ``extra_headers`` sets additional HTTP request headers.
It must be a mapping or an iterable of (name, value) pairs.
"""
request_headers = Headers()
headers = []
set_header = lambda k, v: headers.append((k, v))
if wsuri.port == (443 if wsuri.secure else 80): # pragma: no cover
request_headers["Host"] = wsuri.host
if wsuri.port == (443 if wsuri.secure else 80): # pragma: no cover
set_header('Host', wsuri.host)
else:
request_headers["Host"] = f"{wsuri.host}:{wsuri.port}"
if wsuri.user_info:
request_headers["Authorization"] = build_authorization_basic(
*wsuri.user_info
)
set_header('Host', '{}:{}'.format(wsuri.host, wsuri.port))
if origin is not None:
request_headers["Origin"] = origin
key = build_request(request_headers)
if available_extensions is not None:
extensions_header = build_extension(
[
(extension_factory.name, extension_factory.get_request_params())
for extension_factory in available_extensions
]
)
request_headers["Sec-WebSocket-Extensions"] = extensions_header
if available_subprotocols is not None:
protocol_header = build_subprotocol(available_subprotocols)
request_headers["Sec-WebSocket-Protocol"] = protocol_header
set_header('Origin', origin)
if subprotocols is not None:
set_header('Sec-WebSocket-Protocol', ', '.join(subprotocols))
if extra_headers is not None:
if isinstance(extra_headers, Headers):
extra_headers = extra_headers.raw_items()
elif isinstance(extra_headers, collections.abc.Mapping):
if isinstance(extra_headers, collections.abc.Mapping):
extra_headers = extra_headers.items()
for name, value in extra_headers:
request_headers[name] = value
set_header(name, value)
set_header('User-Agent', USER_AGENT)
request_headers.setdefault("User-Agent", USER_AGENT)
key = build_request(set_header)
self.write_http_request(wsuri.resource_name, request_headers)
yield from self.write_http_request(wsuri.resource_name, headers)
status_code, response_headers = await self.read_http_response()
if status_code in (301, 302, 303, 307, 308):
if "Location" not in response_headers:
raise InvalidHeader("Location")
raise RedirectHandshake(response_headers["Location"])
elif status_code != 101:
raise InvalidStatusCode(status_code)
status_code, headers = yield from self.read_http_response()
get_header = lambda k: headers.get(k, '')
check_response(response_headers, key)
if status_code != 101:
raise InvalidStatusCode(status_code)
self.extensions = self.process_extensions(
response_headers, available_extensions
)
check_response(get_header, key)
self.subprotocol = self.process_subprotocol(
response_headers, available_subprotocols
)
self.subprotocol = self.process_subprotocol(get_header, subprotocols)
self.connection_open()
assert self.state == CONNECTING
self.state = OPEN
self.opening_handshake.set_result(True)
class Connect:
@asyncio.coroutine
def connect(uri, *,
create_protocol=None,
timeout=10, max_size=2 ** 20, max_queue=2 ** 5,
read_limit=2 ** 16, write_limit=2 ** 16,
loop=None, legacy_recv=False, klass=None,
origin=None, subprotocols=None, extra_headers=None,
**kwds):
"""
Connect to the WebSocket server at the given ``uri``.
Awaiting :func:`connect` yields a :class:`WebSocketClientProtocol` which
can then be used to send and receive messages.
This coroutine connects to a WebSocket server at a given ``uri``.
:func:`connect` can also be used as an asynchronous context manager. In
that case, the connection is closed when exiting the context.
It yields a :class:`WebSocketClientProtocol` which can then be used to
send and receive messages.
:func:`connect` is a wrapper around the event loop's
:meth:`~asyncio.loop.create_connection` method. Unknown keyword arguments
are passed to :meth:`~asyncio.loop.create_connection`.
:meth:`~asyncio.BaseEventLoop.create_connection` method. Unknown keyword
arguments are passed to :meth:`~asyncio.BaseEventLoop.create_connection`.
For example, you can set the ``ssl`` keyword argument to a
:class:`~ssl.SSLContext` to enforce some TLS settings. When connecting to
a ``wss://`` URI, if this argument isn't provided explicitly,
:func:`ssl.create_default_context` is called to create a context.
You can connect to a different host and port from those found in ``uri``
by setting ``host`` and ``port`` keyword arguments. This only changes the
destination of the TCP connection. The host name from ``uri`` is still
used in the TLS handshake for secure connections and in the ``Host`` HTTP
header.
The ``create_protocol`` parameter allows customizing the
:class:`~asyncio.Protocol` that manages the connection. It should be a
callable or class accepting the same arguments as
:class:`WebSocketClientProtocol` and returning an instance of
:class:`WebSocketClientProtocol` or a subclass. It defaults to
:class:`WebSocketClientProtocol`.
a ``wss://`` URI, if this argument isn't provided explicitly, it's set to
``True``, which means Python's default :class:`~ssl.SSLContext` is used.
The behavior of the ``timeout``, ``max_size``, ``max_queue``,
``read_limit``, and ``write_limit`` optional arguments is described in the
documentation of :class:`~websockets.protocol.WebSocketCommonProtocol`.
The behavior of ``ping_interval``, ``ping_timeout``, ``close_timeout``,
``max_size``, ``max_queue``, ``read_limit``, and ``write_limit`` is
described in :class:`~websockets.protocol.WebSocketCommonProtocol`.
The ``create_protocol`` parameter allows customizing the asyncio protocol
that manages the connection. It should be a callable or class accepting
the same arguments as :class:`WebSocketClientProtocol` and returning a
:class:`WebSocketClientProtocol` instance. It defaults to
:class:`WebSocketClientProtocol`.
:func:`connect` also accepts the following optional arguments:
* ``compression`` is a shortcut to configure compression extensions;
by default it enables the "permessage-deflate" extension; set it to
``None`` to disable compression
* ``origin`` sets the Origin HTTP header
* ``extensions`` is a list of supported extensions in order of
decreasing preference
* ``subprotocols`` is a list of supported subprotocols in order of
decreasing preference
* ``extra_headers`` sets additional HTTP request headers; it can be a
:class:`~websockets.http.Headers` instance, a
:class:`~collections.abc.Mapping`, or an iterable of ``(name, value)``
pairs
:raises ~websockets.uri.InvalidURI: if ``uri`` is invalid
:raises ~websockets.handshake.InvalidHandshake: if the opening handshake
fails
"""
MAX_REDIRECTS_ALLOWED = 10
def __init__(
self,
uri: str,
*,
path: Optional[str] = None,
create_protocol: Optional[Type[WebSocketClientProtocol]] = None,
ping_interval: float = 20,
ping_timeout: float = 20,
close_timeout: Optional[float] = None,
max_size: int = 2 ** 20,
max_queue: int = 2 ** 5,
read_limit: int = 2 ** 16,
write_limit: int = 2 ** 16,
loop: Optional[asyncio.AbstractEventLoop] = None,
legacy_recv: bool = False,
klass: Optional[Type[WebSocketClientProtocol]] = None,
timeout: Optional[float] = None,
compression: Optional[str] = "deflate",
origin: Optional[Origin] = None,
extensions: Optional[Sequence[ClientExtensionFactory]] = None,
subprotocols: Optional[Sequence[Subprotocol]] = None,
extra_headers: Optional[HeadersLike] = None,
**kwargs: Any,
) -> None:
# Backwards compatibility: close_timeout used to be called timeout.
if timeout is None:
timeout = 10
else:
warnings.warn("rename timeout to close_timeout", DeprecationWarning)
# If both are specified, timeout is ignored.
if close_timeout is None:
close_timeout = timeout
# Backwards compatibility: create_protocol used to be called klass.
if klass is None:
klass = WebSocketClientProtocol
else:
warnings.warn("rename klass to create_protocol", DeprecationWarning)
# If both are specified, klass is ignored.
if create_protocol is None:
create_protocol = klass
if loop is None:
loop = asyncio.get_event_loop()
wsuri = parse_uri(uri)
if wsuri.secure:
kwargs.setdefault("ssl", True)
elif kwargs.get("ssl") is not None:
raise ValueError(
"connect() received a ssl argument for a ws:// URI, "
"use a wss:// URI to enable TLS"
)
if compression == "deflate":
if extensions is None:
extensions = []
if not any(
extension_factory.name == ClientPerMessageDeflateFactory.name
for extension_factory in extensions
):
extensions = list(extensions) + [
ClientPerMessageDeflateFactory(client_max_window_bits=True)
]
elif compression is not None:
raise ValueError(f"unsupported compression: {compression}")
factory = functools.partial(
create_protocol,
ping_interval=ping_interval,
ping_timeout=ping_timeout,
close_timeout=close_timeout,
max_size=max_size,
max_queue=max_queue,
read_limit=read_limit,
write_limit=write_limit,
loop=loop,
host=wsuri.host,
port=wsuri.port,
secure=wsuri.secure,
legacy_recv=legacy_recv,
origin=origin,
extensions=extensions,
subprotocols=subprotocols,
extra_headers=extra_headers,
)
if path is None:
host: Optional[str]
port: Optional[int]
if kwargs.get("sock") is None:
host, port = wsuri.host, wsuri.port
else:
# If sock is given, host and port shouldn't be specified.
host, port = None, None
# If host and port are given, override values from the URI.
host = kwargs.pop("host", host)
port = kwargs.pop("port", port)
create_connection = functools.partial(
loop.create_connection, factory, host, port, **kwargs
)
else:
create_connection = functools.partial(
loop.create_unix_connection, factory, path, **kwargs
)
# This is a coroutine function.
self._create_connection = create_connection
self._wsuri = wsuri
def handle_redirect(self, uri: str) -> None:
# Update the state of this instance to connect to a new URI.
old_wsuri = self._wsuri
new_wsuri = parse_uri(uri)
# Forbid TLS downgrade.
if old_wsuri.secure and not new_wsuri.secure:
raise SecurityError("redirect from WSS to WS")
same_origin = (
old_wsuri.host == new_wsuri.host and old_wsuri.port == new_wsuri.port
)
# Rewrite the host and port arguments for cross-origin redirects.
# This preserves connection overrides with the host and port
# arguments if the redirect points to the same host and port.
if not same_origin:
# Replace the host and port argument passed to the protocol factory.
factory = self._create_connection.args[0]
factory = functools.partial(
factory.func,
*factory.args,
**dict(factory.keywords, host=new_wsuri.host, port=new_wsuri.port),
)
# Replace the host and port argument passed to create_connection.
self._create_connection = functools.partial(
self._create_connection.func,
*(factory, new_wsuri.host, new_wsuri.port),
**self._create_connection.keywords,
)
# Set the new WebSocket URI. This suffices for same-origin redirects.
self._wsuri = new_wsuri
# async with connect(...)
async def __aenter__(self) -> WebSocketClientProtocol:
return await self
async def __aexit__(
self,
exc_type: Optional[Type[BaseException]],
exc_value: Optional[BaseException],
traceback: Optional[TracebackType],
) -> None:
await self.ws_client.close()
# await connect(...)
def __await__(self) -> Generator[Any, None, WebSocketClientProtocol]:
# Create a suitable iterator by calling __await__ on a coroutine.
return self.__await_impl__().__await__()
async def __await_impl__(self) -> WebSocketClientProtocol:
for redirects in range(self.MAX_REDIRECTS_ALLOWED):
transport, protocol = await self._create_connection()
# https://github.com/python/typeshed/pull/2756
transport = cast(asyncio.Transport, transport)
protocol = cast(WebSocketClientProtocol, protocol)
try:
try:
await protocol.handshake(
self._wsuri,
origin=protocol.origin,
available_extensions=protocol.available_extensions,
available_subprotocols=protocol.available_subprotocols,
extra_headers=protocol.extra_headers,
)
except Exception:
protocol.fail_connection()
await protocol.wait_closed()
raise
else:
self.ws_client = protocol
return protocol
except RedirectHandshake as exc:
self.handle_redirect(exc.uri)
else:
raise SecurityError("too many redirects")
# yield from connect(...)
__iter__ = __await__
connect = Connect
def unix_connect(path: str, uri: str = "ws://localhost/", **kwargs: Any) -> Connect:
"""
Similar to :func:`connect`, but for connecting to a Unix socket.
This function calls the event loop's
:meth:`~asyncio.loop.create_unix_connection` method.
It is only available on Unix.
* ``extra_headers`` sets additional HTTP request headers; it can be a
mapping or an iterable of (name, value) pairs
It's mainly useful for debugging servers listening on Unix sockets.
:func:`connect` raises :exc:`~websockets.uri.InvalidURI` if ``uri`` is
invalid and :exc:`~websockets.handshake.InvalidHandshake` if the opening
handshake fails.
:param path: file system path to the Unix socket
:param uri: WebSocket URI
On Python 3.5, :func:`connect` can be used as an asynchronous context
manager. In that case, the connection is closed when exiting the context.
"""
return connect(uri=uri, path=path, **kwargs)
if loop is None:
loop = asyncio.get_event_loop()
# Backwards-compatibility: create_protocol used to be called klass.
# In the unlikely event that both are specified, klass is ignored.
if create_protocol is None:
create_protocol = klass
if create_protocol is None:
create_protocol = WebSocketClientProtocol
wsuri = parse_uri(uri)
if wsuri.secure:
kwds.setdefault('ssl', True)
elif kwds.get('ssl') is not None:
raise ValueError("connect() received a SSL context for a ws:// URI. "
"Use a wss:// URI to enable TLS.")
factory = lambda: create_protocol(
host=wsuri.host, port=wsuri.port, secure=wsuri.secure,
timeout=timeout, max_size=max_size, max_queue=max_queue,
read_limit=read_limit, write_limit=write_limit,
loop=loop, legacy_recv=legacy_recv,
)
transport, protocol = yield from loop.create_connection(
factory, wsuri.host, wsuri.port, **kwds)
try:
yield from protocol.handshake(
wsuri, origin=origin, subprotocols=subprotocols,
extra_headers=extra_headers)
except Exception:
yield from protocol.close_connection(force=True)
raise
return protocol
try:
from .py35.client import Connect
except (SyntaxError, ImportError): # pragma: no cover
pass
else:
Connect.__wrapped__ = connect
# Copy over docstring to support building documentation on Python 3.5.
Connect.__doc__ = connect.__doc__
connect = Connect
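# A hedged usage sketch (editor's addition): ``connect`` as an asynchronous
# context manager. The URI, header, and messages are placeholders.
if __name__ == "__main__":
    import asyncio

    import websockets

    async def main():
        async with websockets.connect(
            "ws://localhost:8765",
            extra_headers={"X-Demo": "1"},
        ) as websocket:
            await websocket.send("ping")
            reply = await websocket.recv()
            print(reply)

    asyncio.get_event_loop().run_until_complete(main())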

@ -0,0 +1,48 @@
import asyncio
import http
# Replace with BaseEventLoop.create_task when dropping Python < 3.4.2.
try: # pragma: no cover
asyncio_ensure_future = asyncio.ensure_future # Python ≥ 3.5
except AttributeError: # pragma: no cover
asyncio_ensure_future = asyncio.async # Python < 3.5
try: # pragma: no cover
# Python ≥ 3.5
SWITCHING_PROTOCOLS = http.HTTPStatus.SWITCHING_PROTOCOLS
OK = http.HTTPStatus.OK
BAD_REQUEST = http.HTTPStatus.BAD_REQUEST
UNAUTHORIZED = http.HTTPStatus.UNAUTHORIZED
FORBIDDEN = http.HTTPStatus.FORBIDDEN
INTERNAL_SERVER_ERROR = http.HTTPStatus.INTERNAL_SERVER_ERROR
SERVICE_UNAVAILABLE = http.HTTPStatus.SERVICE_UNAVAILABLE
except AttributeError: # pragma: no cover
# Python < 3.5
class SWITCHING_PROTOCOLS:
value = 101
phrase = "Switching Protocols"
class OK:
value = 200
phrase = "OK"
class BAD_REQUEST:
value = 400
phrase = "Bad Request"
class UNAUTHORIZED:
value = 401
phrase = "Unauthorized"
class FORBIDDEN:
value = 403
phrase = "Forbidden"
class INTERNAL_SERVER_ERROR:
value = 500
phrase = "Internal Server Error"
class SERVICE_UNAVAILABLE:
value = 503
phrase = "Service Unavailable"
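# Editor's note: both branches expose objects with ``value`` and ``phrase``
# attributes, so callers can treat them uniformly, e.g.
# ``SWITCHING_PROTOCOLS.value == 101`` and
# ``SWITCHING_PROTOCOLS.phrase == "Switching Protocols"``.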

@ -1,366 +1,95 @@
"""
:mod:`websockets.exceptions` defines the following exception hierarchy:
* :exc:`WebSocketException`
* :exc:`ConnectionClosed`
* :exc:`ConnectionClosedError`
* :exc:`ConnectionClosedOK`
* :exc:`InvalidHandshake`
* :exc:`SecurityError`
* :exc:`InvalidMessage`
* :exc:`InvalidHeader`
* :exc:`InvalidHeaderFormat`
* :exc:`InvalidHeaderValue`
* :exc:`InvalidOrigin`
* :exc:`InvalidUpgrade`
* :exc:`InvalidStatusCode`
* :exc:`NegotiationError`
* :exc:`DuplicateParameter`
* :exc:`InvalidParameterName`
* :exc:`InvalidParameterValue`
* :exc:`AbortHandshake`
* :exc:`RedirectHandshake`
* :exc:`InvalidState`
* :exc:`InvalidURI`
* :exc:`PayloadTooBig`
* :exc:`ProtocolError`
"""
import http
from typing import Optional
from .http import Headers, HeadersLike
__all__ = [
"WebSocketException",
"ConnectionClosed",
"ConnectionClosedError",
"ConnectionClosedOK",
"InvalidHandshake",
"SecurityError",
"InvalidMessage",
"InvalidHeader",
"InvalidHeaderFormat",
"InvalidHeaderValue",
"InvalidOrigin",
"InvalidUpgrade",
"InvalidStatusCode",
"NegotiationError",
"DuplicateParameter",
"InvalidParameterName",
"InvalidParameterValue",
"AbortHandshake",
"RedirectHandshake",
"InvalidState",
"InvalidURI",
"PayloadTooBig",
"ProtocolError",
"WebSocketProtocolError",
'AbortHandshake', 'InvalidHandshake', 'InvalidMessage', 'InvalidOrigin',
'InvalidState', 'InvalidStatusCode', 'InvalidURI', 'ConnectionClosed',
'PayloadTooBig', 'WebSocketProtocolError',
]
class WebSocketException(Exception):
"""
Base class for all exceptions defined by :mod:`websockets`.
"""
CLOSE_CODES = {
1000: "OK",
1001: "going away",
1002: "protocol error",
1003: "unsupported type",
# 1004 is reserved
1005: "no status code [internal]",
1006: "connection closed abnormally [internal]",
1007: "invalid data",
1008: "policy violation",
1009: "message too big",
1010: "extension required",
1011: "unexpected error",
1015: "TLS failure [internal]",
}
def format_close(code: int, reason: str) -> str:
"""
Display a human-readable version of the close code and reason.
"""
if 3000 <= code < 4000:
explanation = "registered"
elif 4000 <= code < 5000:
explanation = "private use"
else:
explanation = CLOSE_CODES.get(code, "unknown")
result = f"code = {code} ({explanation}), "
if reason:
result += f"reason = {reason}"
else:
result += "no reason"
return result
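# For example (editor's note):
#   format_close(1000, "")              -> "code = 1000 (OK), no reason"
#   format_close(3001, "shutting down") -> "code = 3001 (registered), reason = shutting down"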
class ConnectionClosed(WebSocketException):
"""
Raised when trying to interact with a closed connection.
Provides the connection close code and reason in its ``code`` and
``reason`` attributes respectively.
"""
def __init__(self, code: int, reason: str) -> None:
self.code = code
self.reason = reason
super().__init__(format_close(code, reason))
class ConnectionClosedError(ConnectionClosed):
"""
Like :exc:`ConnectionClosed`, when the connection terminated with an error.
This means the close code is different from 1000 (OK) and 1001 (going away).
"""
def __init__(self, code: int, reason: str) -> None:
assert code != 1000 and code != 1001
super().__init__(code, reason)
class ConnectionClosedOK(ConnectionClosed):
"""
Like :exc:`ConnectionClosed`, when the connection terminated properly.
This means the close code is 1000 (OK) or 1001 (going away).
class InvalidHandshake(Exception):
"""
Exception raised when a handshake request or response is invalid.
def __init__(self, code: int, reason: str) -> None:
assert code == 1000 or code == 1001
super().__init__(code, reason)
class InvalidHandshake(WebSocketException):
"""
Raised during the handshake when the WebSocket connection fails.
"""
class SecurityError(InvalidHandshake):
class AbortHandshake(InvalidHandshake):
"""
Raised when a handshake request or response breaks a security rule.
Security limits are hard coded.
Exception raised to abort a handshake and return an HTTP response.
"""
def __init__(self, status, headers, body=None):
self.status = status
self.headers = headers
self.body = body
class InvalidMessage(InvalidHandshake):
"""
Raised when a handshake request or response is malformed.
"""
class InvalidHeader(InvalidHandshake):
"""
Raised when an HTTP header doesn't have a valid format or value.
Exception raised when the HTTP message in a handshake request is malformed.
"""
def __init__(self, name: str, value: Optional[str] = None) -> None:
self.name = name
self.value = value
if value is None:
message = f"missing {name} header"
elif value == "":
message = f"empty {name} header"
else:
message = f"invalid {name} header: {value}"
super().__init__(message)
class InvalidHeaderFormat(InvalidHeader):
class InvalidOrigin(InvalidHandshake):
"""
Raised when an HTTP header cannot be parsed.
The format of the header doesn't match the grammar for that header.
"""
def __init__(self, name: str, error: str, header: str, pos: int) -> None:
self.name = name
error = f"{error} at {pos} in {header}"
super().__init__(name, error)
class InvalidHeaderValue(InvalidHeader):
"""
Raised when an HTTP header has a wrong value.
The format of the header is correct but a value isn't acceptable.
"""
class InvalidOrigin(InvalidHeader):
"""
Raised when the Origin header in a request isn't allowed.
"""
def __init__(self, origin: Optional[str]) -> None:
super().__init__("Origin", origin)
class InvalidUpgrade(InvalidHeader):
"""
Raised when the Upgrade or Connection header isn't correct.
Exception raised when the origin in a handshake request is forbidden.
"""
class InvalidStatusCode(InvalidHandshake):
"""
Raised when a handshake response status code is invalid.
Exception raised when a handshake response status code is invalid.
The integer status code is available in the ``status_code`` attribute.
Provides the integer status code in its ``status_code`` attribute.
"""
def __init__(self, status_code: int) -> None:
def __init__(self, status_code):
self.status_code = status_code
message = f"server rejected WebSocket connection: HTTP {status_code}"
super().__init__(message)
class NegotiationError(InvalidHandshake):
"""
Raised when negotiating an extension fails.
"""
class DuplicateParameter(NegotiationError):
"""
Raised when a parameter name is repeated in an extension header.
"""
def __init__(self, name: str) -> None:
self.name = name
message = f"duplicate parameter: {name}"
message = 'Status code not 101: {}'.format(status_code)
super().__init__(message)
class InvalidParameterName(NegotiationError):
class InvalidState(Exception):
"""
Raised when a parameter name in an extension header is invalid.
Exception raised when an operation is forbidden in the current state.
"""
def __init__(self, name: str) -> None:
self.name = name
message = f"invalid parameter name: {name}"
super().__init__(message)
class InvalidParameterValue(NegotiationError):
"""
Raised when a parameter value in an extension header is invalid.
class ConnectionClosed(InvalidState):
"""
Exception raised when trying to read or write on a closed connection.
def __init__(self, name: str, value: Optional[str]) -> None:
self.name = name
self.value = value
if value is None:
message = f"missing value for parameter {name}"
elif value == "":
message = f"empty value for parameter {name}"
else:
message = f"invalid value for parameter {name}: {value}"
super().__init__(message)
class AbortHandshake(InvalidHandshake):
"""
Raised to abort the handshake on purpose and return an HTTP response.
This exception is an implementation detail.
The public API is :meth:`~server.WebSocketServerProtocol.process_request`.
Provides the connection close code and reason in its ``code`` and
``reason`` attributes respectively.
"""
def __init__(
self, status: http.HTTPStatus, headers: HeadersLike, body: bytes = b""
) -> None:
self.status = status
self.headers = Headers(headers)
self.body = body
message = f"HTTP {status}, {len(self.headers)} headers, {len(body)} bytes"
def __init__(self, code, reason):
self.code = code
self.reason = reason
message = 'WebSocket connection is closed: '
message += 'code = {}, '.format(code) if code else 'no code, '
message += 'reason = {}.'.format(reason) if reason else 'no reason.'
super().__init__(message)
class RedirectHandshake(InvalidHandshake):
class InvalidURI(Exception):
"""
Raised when a handshake gets redirected.
This exception is an implementation detail.
Exception raised when a URI isn't a valid websocket URI.
"""
def __init__(self, uri: str) -> None:
self.uri = uri
def __str__(self) -> str:
return f"redirect to {self.uri}"
class InvalidState(WebSocketException, AssertionError):
class PayloadTooBig(Exception):
"""
Raised when an operation is forbidden in the current state.
This exception is an implementation detail.
It should never be raised in normal circumstances.
Exception raised when a frame's payload exceeds the maximum size.
"""
class InvalidURI(WebSocketException):
class WebSocketProtocolError(Exception):
"""
Raised when connecting to a URI that isn't a valid WebSocket URI.
Internal exception raised when the remote side breaks the protocol.
"""
def __init__(self, uri: str) -> None:
self.uri = uri
message = "{} isn't a valid URI".format(uri)
super().__init__(message)
class PayloadTooBig(WebSocketException):
"""
Raised when receiving a frame with a payload exceeding the maximum size.
"""
class ProtocolError(WebSocketException):
"""
Raised when the other side breaks the protocol.
"""
WebSocketProtocolError = ProtocolError # for backwards compatibility
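# A hedged usage sketch (editor's addition): distinguishing a clean close from
# a failure when reading; ``websocket`` is assumed to be an open connection
# inside a coroutine.
#
#     try:
#         message = await websocket.recv()
#     except ConnectionClosedOK:
#         pass  # close code 1000 or 1001: normal shutdown
#     except ConnectionClosedError as exc:
#         print(f"connection failed: {exc}")  # any other close code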

@ -1,119 +0,0 @@
"""
:mod:`websockets.extensions.base` defines abstract classes for implementing
extensions.
See `section 9 of RFC 6455`_.
.. _section 9 of RFC 6455: http://tools.ietf.org/html/rfc6455#section-9
"""
from typing import List, Optional, Sequence, Tuple
from ..framing import Frame
from ..typing import ExtensionName, ExtensionParameter
__all__ = ["Extension", "ClientExtensionFactory", "ServerExtensionFactory"]
class Extension:
"""
Abstract class for extensions.
"""
@property
def name(self) -> ExtensionName:
"""
Extension identifier.
"""
def decode(self, frame: Frame, *, max_size: Optional[int] = None) -> Frame:
"""
Decode an incoming frame.
:param frame: incoming frame
:param max_size: maximum payload size in bytes
"""
def encode(self, frame: Frame) -> Frame:
"""
Encode an outgoing frame.
:param frame: outgoing frame
"""
class ClientExtensionFactory:
"""
Abstract class for client-side extension factories.
"""
@property
def name(self) -> ExtensionName:
"""
Extension identifier.
"""
def get_request_params(self) -> List[ExtensionParameter]:
"""
Build request parameters.
Return a list of ``(name, value)`` pairs.
"""
def process_response_params(
self,
params: Sequence[ExtensionParameter],
accepted_extensions: Sequence[Extension],
) -> Extension:
"""
Process response parameters received from the server.
:param params: list of ``(name, value)`` pairs.
:param accepted_extensions: list of previously accepted extensions.
:raises ~websockets.exceptions.NegotiationError: if parameters aren't
acceptable
"""
class ServerExtensionFactory:
"""
Abstract class for server-side extension factories.
"""
@property
def name(self) -> ExtensionName:
"""
Extension identifier.
"""
def process_request_params(
self,
params: Sequence[ExtensionParameter],
accepted_extensions: Sequence[Extension],
) -> Tuple[List[ExtensionParameter], Extension]:
"""
Process request parameters received from the client.
To accept the offer, return a 2-tuple containing:
- response parameters: a list of ``(name, value)`` pairs
- an extension: an instance of a subclass of :class:`Extension`
:param params: list of ``(name, value)`` pairs.
:param accepted_extensions: list of previously accepted extensions.
:raises ~websockets.exceptions.NegotiationError: to reject the offer,
if parameters aren't acceptable
"""

@ -1,588 +0,0 @@
"""
:mod:`websockets.extensions.permessage_deflate` implements the Compression
Extensions for WebSocket as specified in :rfc:`7692`.
"""
import zlib
from typing import Any, Dict, List, Optional, Sequence, Tuple, Union
from ..exceptions import (
DuplicateParameter,
InvalidParameterName,
InvalidParameterValue,
NegotiationError,
PayloadTooBig,
)
from ..framing import CTRL_OPCODES, OP_CONT, Frame
from ..typing import ExtensionName, ExtensionParameter
from .base import ClientExtensionFactory, Extension, ServerExtensionFactory
__all__ = [
"PerMessageDeflate",
"ClientPerMessageDeflateFactory",
"ServerPerMessageDeflateFactory",
]
_EMPTY_UNCOMPRESSED_BLOCK = b"\x00\x00\xff\xff"
_MAX_WINDOW_BITS_VALUES = [str(bits) for bits in range(8, 16)]
class PerMessageDeflate(Extension):
"""
Per-Message Deflate extension.
"""
name = ExtensionName("permessage-deflate")
def __init__(
self,
remote_no_context_takeover: bool,
local_no_context_takeover: bool,
remote_max_window_bits: int,
local_max_window_bits: int,
compress_settings: Optional[Dict[Any, Any]] = None,
) -> None:
"""
Configure the Per-Message Deflate extension.
"""
if compress_settings is None:
compress_settings = {}
assert remote_no_context_takeover in [False, True]
assert local_no_context_takeover in [False, True]
assert 8 <= remote_max_window_bits <= 15
assert 8 <= local_max_window_bits <= 15
assert "wbits" not in compress_settings
self.remote_no_context_takeover = remote_no_context_takeover
self.local_no_context_takeover = local_no_context_takeover
self.remote_max_window_bits = remote_max_window_bits
self.local_max_window_bits = local_max_window_bits
self.compress_settings = compress_settings
if not self.remote_no_context_takeover:
self.decoder = zlib.decompressobj(wbits=-self.remote_max_window_bits)
if not self.local_no_context_takeover:
self.encoder = zlib.compressobj(
wbits=-self.local_max_window_bits, **self.compress_settings
)
# To handle continuation frames properly, we must keep track of
# whether that initial frame was encoded.
self.decode_cont_data = False
# There's no need for self.encode_cont_data because we always encode
# outgoing frames, so it would always be True.
def __repr__(self) -> str:
return (
f"PerMessageDeflate("
f"remote_no_context_takeover={self.remote_no_context_takeover}, "
f"local_no_context_takeover={self.local_no_context_takeover}, "
f"remote_max_window_bits={self.remote_max_window_bits}, "
f"local_max_window_bits={self.local_max_window_bits})"
)
def decode(self, frame: Frame, *, max_size: Optional[int] = None) -> Frame:
"""
Decode an incoming frame.
"""
# Skip control frames.
if frame.opcode in CTRL_OPCODES:
return frame
# Handle continuation data frames:
# - skip if the initial data frame wasn't encoded
# - reset "decode continuation data" flag if it's a final frame
if frame.opcode == OP_CONT:
if not self.decode_cont_data:
return frame
if frame.fin:
self.decode_cont_data = False
# Handle text and binary data frames:
# - skip if the frame isn't encoded
# - set "decode continuation data" flag if it's a non-final frame
else:
if not frame.rsv1:
return frame
if not frame.fin: # frame.rsv1 is True at this point
self.decode_cont_data = True
# Re-initialize per-message decoder.
if self.remote_no_context_takeover:
self.decoder = zlib.decompressobj(wbits=-self.remote_max_window_bits)
# Uncompress compressed frames. Protect against zip bombs by
# preventing zlib from decompressing more than max_length bytes
# (except when the limit is disabled with max_size = None).
data = frame.data
if frame.fin:
data += _EMPTY_UNCOMPRESSED_BLOCK
max_length = 0 if max_size is None else max_size
data = self.decoder.decompress(data, max_length)
if self.decoder.unconsumed_tail:
raise PayloadTooBig(
f"Uncompressed payload length exceeds size limit (? > {max_size} bytes)"
)
# Allow garbage collection of the decoder if it won't be reused.
if frame.fin and self.remote_no_context_takeover:
del self.decoder
return frame._replace(data=data, rsv1=False)
def encode(self, frame: Frame) -> Frame:
"""
Encode an outgoing frame.
"""
# Skip control frames.
if frame.opcode in CTRL_OPCODES:
return frame
# Since we always encode and never fragment messages, there's no logic
# similar to decode() here at this time.
if frame.opcode != OP_CONT:
# Re-initialize per-message encoder.
if self.local_no_context_takeover:
self.encoder = zlib.compressobj(
wbits=-self.local_max_window_bits, **self.compress_settings
)
# Compress data frames.
data = self.encoder.compress(frame.data) + self.encoder.flush(zlib.Z_SYNC_FLUSH)
if frame.fin and data.endswith(_EMPTY_UNCOMPRESSED_BLOCK):
data = data[:-4]
# Allow garbage collection of the encoder if it won't be reused.
if frame.fin and self.local_no_context_takeover:
del self.encoder
return frame._replace(data=data, rsv1=True)
def _build_parameters(
server_no_context_takeover: bool,
client_no_context_takeover: bool,
server_max_window_bits: Optional[int],
client_max_window_bits: Optional[Union[int, bool]],
) -> List[ExtensionParameter]:
"""
Build a list of ``(name, value)`` pairs for some compression parameters.
"""
params: List[ExtensionParameter] = []
if server_no_context_takeover:
params.append(("server_no_context_takeover", None))
if client_no_context_takeover:
params.append(("client_no_context_takeover", None))
if server_max_window_bits:
params.append(("server_max_window_bits", str(server_max_window_bits)))
if client_max_window_bits is True: # only in handshake requests
params.append(("client_max_window_bits", None))
elif client_max_window_bits:
params.append(("client_max_window_bits", str(client_max_window_bits)))
return params
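# For example (editor's note), in a handshake request:
#   _build_parameters(True, False, 12, True)
#   -> [("server_no_context_takeover", None),
#       ("server_max_window_bits", "12"),
#       ("client_max_window_bits", None)]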
def _extract_parameters(
params: Sequence[ExtensionParameter], *, is_server: bool
) -> Tuple[bool, bool, Optional[int], Optional[Union[int, bool]]]:
"""
Extract compression parameters from a list of ``(name, value)`` pairs.
If ``is_server`` is ``True``, ``client_max_window_bits`` may be provided
without a value. This is only allowed in handshake requests.
"""
server_no_context_takeover: bool = False
client_no_context_takeover: bool = False
server_max_window_bits: Optional[int] = None
client_max_window_bits: Optional[Union[int, bool]] = None
for name, value in params:
if name == "server_no_context_takeover":
if server_no_context_takeover:
raise DuplicateParameter(name)
if value is None:
server_no_context_takeover = True
else:
raise InvalidParameterValue(name, value)
elif name == "client_no_context_takeover":
if client_no_context_takeover:
raise DuplicateParameter(name)
if value is None:
client_no_context_takeover = True
else:
raise InvalidParameterValue(name, value)
elif name == "server_max_window_bits":
if server_max_window_bits is not None:
raise DuplicateParameter(name)
if value in _MAX_WINDOW_BITS_VALUES:
server_max_window_bits = int(value)
else:
raise InvalidParameterValue(name, value)
elif name == "client_max_window_bits":
if client_max_window_bits is not None:
raise DuplicateParameter(name)
if is_server and value is None: # only in handshake requests
client_max_window_bits = True
elif value in _MAX_WINDOW_BITS_VALUES:
client_max_window_bits = int(value)
else:
raise InvalidParameterValue(name, value)
else:
raise InvalidParameterName(name)
return (
server_no_context_takeover,
client_no_context_takeover,
server_max_window_bits,
client_max_window_bits,
)
class ClientPerMessageDeflateFactory(ClientExtensionFactory):
"""
Client-side extension factory for the Per-Message Deflate extension.
Parameters behave as described in `section 7.1 of RFC 7692`_. Set them to
``True`` to include them in the negotiation offer without a value or to an
integer value to include them with this value.
.. _section 7.1 of RFC 7692: https://tools.ietf.org/html/rfc7692#section-7.1
:param server_no_context_takeover: defaults to ``False``
:param client_no_context_takeover: defaults to ``False``
:param server_max_window_bits: optional, defaults to ``None``
:param client_max_window_bits: optional, defaults to ``None``
:param compress_settings: optional, keyword arguments for
:func:`zlib.compressobj`, excluding ``wbits``
"""
name = ExtensionName("permessage-deflate")
def __init__(
self,
server_no_context_takeover: bool = False,
client_no_context_takeover: bool = False,
server_max_window_bits: Optional[int] = None,
client_max_window_bits: Optional[Union[int, bool]] = None,
compress_settings: Optional[Dict[str, Any]] = None,
) -> None:
"""
Configure the Per-Message Deflate extension factory.
"""
if not (server_max_window_bits is None or 8 <= server_max_window_bits <= 15):
raise ValueError("server_max_window_bits must be between 8 and 15")
if not (
client_max_window_bits is None
or client_max_window_bits is True
or 8 <= client_max_window_bits <= 15
):
raise ValueError("client_max_window_bits must be between 8 and 15")
if compress_settings is not None and "wbits" in compress_settings:
raise ValueError(
"compress_settings must not include wbits, "
"set client_max_window_bits instead"
)
self.server_no_context_takeover = server_no_context_takeover
self.client_no_context_takeover = client_no_context_takeover
self.server_max_window_bits = server_max_window_bits
self.client_max_window_bits = client_max_window_bits
self.compress_settings = compress_settings
def get_request_params(self) -> List[ExtensionParameter]:
"""
Build request parameters.
"""
return _build_parameters(
self.server_no_context_takeover,
self.client_no_context_takeover,
self.server_max_window_bits,
self.client_max_window_bits,
)
def process_response_params(
self,
params: Sequence[ExtensionParameter],
accepted_extensions: Sequence["Extension"],
) -> PerMessageDeflate:
"""
Process response parameters.
Return an extension instance.
"""
if any(other.name == self.name for other in accepted_extensions):
raise NegotiationError(f"received duplicate {self.name}")
# Request parameters are available in instance variables.
# Load response parameters in local variables.
(
server_no_context_takeover,
client_no_context_takeover,
server_max_window_bits,
client_max_window_bits,
) = _extract_parameters(params, is_server=False)
# After comparing the request and the response, the final
# configuration must be available in the local variables.
# server_no_context_takeover
#
# Req. Resp. Result
# ------ ------ --------------------------------------------------
# False False False
# False True True
# True False Error!
# True True True
if self.server_no_context_takeover:
if not server_no_context_takeover:
raise NegotiationError("expected server_no_context_takeover")
# client_no_context_takeover
#
# Req. Resp. Result
# ------ ------ --------------------------------------------------
# False False False
# False True True
# True False True - must change value
# True True True
if self.client_no_context_takeover:
if not client_no_context_takeover:
client_no_context_takeover = True
# server_max_window_bits
# Req. Resp. Result
# ------ ------ --------------------------------------------------
# None None None
# None 8≤M≤15 M
# 8≤N≤15 None Error!
# 8≤N≤15 8≤M≤N M
# 8≤N≤15 N<M≤15 Error!
if self.server_max_window_bits is None:
pass
else:
if server_max_window_bits is None:
raise NegotiationError("expected server_max_window_bits")
elif server_max_window_bits > self.server_max_window_bits:
raise NegotiationError("unsupported server_max_window_bits")
# client_max_window_bits
# Req. Resp. Result
# ------ ------ --------------------------------------------------
# None None None
# None 8≤M≤15 Error!
# True None None
# True 8≤M≤15 M
# 8≤N≤15 None N - must change value
# 8≤N≤15 8≤M≤N M
# 8≤N≤15 N<M≤15 Error!
if self.client_max_window_bits is None:
if client_max_window_bits is not None:
raise NegotiationError("unexpected client_max_window_bits")
elif self.client_max_window_bits is True:
pass
else:
if client_max_window_bits is None:
client_max_window_bits = self.client_max_window_bits
elif client_max_window_bits > self.client_max_window_bits:
raise NegotiationError("unsupported client_max_window_bits")
return PerMessageDeflate(
server_no_context_takeover, # remote_no_context_takeover
client_no_context_takeover, # local_no_context_takeover
server_max_window_bits or 15, # remote_max_window_bits
client_max_window_bits or 15, # local_max_window_bits
self.compress_settings,
)
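# A minimal sketch of the client-side negotiation flow described by the
# tables above (assuming this module is importable as
# websockets.extensions.permessage_deflate):
from websockets.extensions.permessage_deflate import ClientPerMessageDeflateFactory
factory = ClientPerMessageDeflateFactory(client_max_window_bits=12)
assert factory.get_request_params() == [("client_max_window_bits", "12")]
# A server answering with a smaller window is accepted; answering with a
# larger one would raise NegotiationError:
extension = factory.process_response_params([("client_max_window_bits", "10")], [])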
class ServerPerMessageDeflateFactory(ServerExtensionFactory):
"""
Server-side extension factory for the Per-Message Deflate extension.
Parameters behave as described in `section 7.1 of RFC 7692`_. Set them to
``True`` to include them in the negotiation offer without a value or to an
integer value to include them with this value.
.. _section 7.1 of RFC 7692: https://tools.ietf.org/html/rfc7692#section-7.1
:param server_no_context_takeover: defaults to ``False``
:param client_no_context_takeover: defaults to ``False``
:param server_max_window_bits: optional, defaults to ``None``
:param client_max_window_bits: optional, defaults to ``None``
:param compress_settings: optional, keyword arguments for
:func:`zlib.compressobj`, excluding ``wbits``
"""
name = ExtensionName("permessage-deflate")
def __init__(
self,
server_no_context_takeover: bool = False,
client_no_context_takeover: bool = False,
server_max_window_bits: Optional[int] = None,
client_max_window_bits: Optional[int] = None,
compress_settings: Optional[Dict[str, Any]] = None,
) -> None:
"""
Configure the Per-Message Deflate extension factory.
"""
if not (server_max_window_bits is None or 8 <= server_max_window_bits <= 15):
raise ValueError("server_max_window_bits must be between 8 and 15")
if not (client_max_window_bits is None or 8 <= client_max_window_bits <= 15):
raise ValueError("client_max_window_bits must be between 8 and 15")
if compress_settings is not None and "wbits" in compress_settings:
raise ValueError(
"compress_settings must not include wbits, "
"set server_max_window_bits instead"
)
self.server_no_context_takeover = server_no_context_takeover
self.client_no_context_takeover = client_no_context_takeover
self.server_max_window_bits = server_max_window_bits
self.client_max_window_bits = client_max_window_bits
self.compress_settings = compress_settings
def process_request_params(
self,
params: Sequence[ExtensionParameter],
accepted_extensions: Sequence["Extension"],
) -> Tuple[List[ExtensionParameter], PerMessageDeflate]:
"""
Process request parameters.
Return response params and an extension instance.
"""
if any(other.name == self.name for other in accepted_extensions):
raise NegotiationError(f"skipped duplicate {self.name}")
# Load request parameters in local variables.
(
server_no_context_takeover,
client_no_context_takeover,
server_max_window_bits,
client_max_window_bits,
) = _extract_parameters(params, is_server=True)
# Configuration parameters are available in instance variables.
# After comparing the request and the configuration, the response must
# be available in the local variables.
# server_no_context_takeover
#
# Config Req. Resp.
# ------ ------ --------------------------------------------------
# False False False
# False True True
# True False True - must change value to True
# True True True
if self.server_no_context_takeover:
if not server_no_context_takeover:
server_no_context_takeover = True
# client_no_context_takeover
#
# Config Req. Resp.
# ------ ------ --------------------------------------------------
# False False False
# False True True (or False)
# True False True - must change value to True
# True True True (or False)
if self.client_no_context_takeover:
if not client_no_context_takeover:
client_no_context_takeover = True
# server_max_window_bits
# Config Req. Resp.
# ------ ------ --------------------------------------------------
# None None None
# None 8≤M≤15 M
# 8≤N≤15 None N - must change value
# 8≤N≤15 8≤M≤N M
# 8≤N≤15 N<M≤15 N - must change value
if self.server_max_window_bits is None:
pass
else:
if server_max_window_bits is None:
server_max_window_bits = self.server_max_window_bits
elif server_max_window_bits > self.server_max_window_bits:
server_max_window_bits = self.server_max_window_bits
# client_max_window_bits
# Config Req. Resp.
# ------ ------ --------------------------------------------------
# None None None
# None True None - must change value
# None 8≤M≤15 M (or None)
# 8≤N≤15 None Error!
# 8≤N≤15 True N - must change value
# 8≤N≤15 8≤M≤N M (or None)
# 8≤N≤15 N<M≤15 N
if self.client_max_window_bits is None:
if client_max_window_bits is True:
client_max_window_bits = self.client_max_window_bits
else:
if client_max_window_bits is None:
raise NegotiationError("required client_max_window_bits")
elif client_max_window_bits is True:
client_max_window_bits = self.client_max_window_bits
elif self.client_max_window_bits < client_max_window_bits:
client_max_window_bits = self.client_max_window_bits
return (
_build_parameters(
server_no_context_takeover,
client_no_context_takeover,
server_max_window_bits,
client_max_window_bits,
),
PerMessageDeflate(
client_no_context_takeover, # remote_no_context_takeover
server_no_context_takeover, # local_no_context_takeover
client_max_window_bits or 15, # remote_max_window_bits
server_max_window_bits or 15, # local_max_window_bits
self.compress_settings,
),
)
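# A minimal sketch of the server-side negotiation flow described by the
# tables above (assuming this module is importable as
# websockets.extensions.permessage_deflate):
from websockets.extensions.permessage_deflate import ServerPerMessageDeflateFactory
factory = ServerPerMessageDeflateFactory(server_max_window_bits=11)
response_params, extension = factory.process_request_params(
    [("client_max_window_bits", None)], []
)
# The valueless client_max_window_bits offer is dropped because the server
# has no client window size configured; the configured server window stays:
assert response_params == [("server_max_window_bits", "11")]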

@ -1,342 +1,210 @@
"""
:mod:`websockets.framing` reads and writes WebSocket frames.
The :mod:`websockets.framing` module implements data framing as specified in
`section 5 of RFC 6455`_.
It deals with a single frame at a time. Anything that depends on the sequence
of frames is implemented in :mod:`websockets.protocol`.
See `section 5 of RFC 6455`_.
.. _section 5 of RFC 6455: http://tools.ietf.org/html/rfc6455#section-5
"""
import asyncio
import collections
import io
import random
import struct
from typing import Any, Awaitable, Callable, NamedTuple, Optional, Sequence, Tuple
from .exceptions import PayloadTooBig, ProtocolError
from .typing import Data
from .exceptions import PayloadTooBig, WebSocketProtocolError
try:
from .speedups import apply_mask
except ImportError: # pragma: no cover
from .utils import apply_mask
__all__ = [
"DATA_OPCODES",
"CTRL_OPCODES",
"OP_CONT",
"OP_TEXT",
"OP_BINARY",
"OP_CLOSE",
"OP_PING",
"OP_PONG",
"Frame",
"prepare_data",
"encode_data",
"parse_close",
"serialize_close",
'OP_CONT', 'OP_TEXT', 'OP_BINARY', 'OP_CLOSE', 'OP_PING', 'OP_PONG',
'Frame', 'read_frame', 'write_frame', 'parse_close', 'serialize_close'
]
DATA_OPCODES = OP_CONT, OP_TEXT, OP_BINARY = 0x00, 0x01, 0x02
CTRL_OPCODES = OP_CLOSE, OP_PING, OP_PONG = 0x08, 0x09, 0x0A
OP_CONT, OP_TEXT, OP_BINARY = range(0x00, 0x03)
OP_CLOSE, OP_PING, OP_PONG = range(0x08, 0x0b)
CLOSE_CODES = {
1000: "OK",
1001: "going away",
1002: "protocol error",
1003: "unsupported type",
# 1004: - (reserved)
# 1005: no status code (internal)
# 1006: connection closed abnormally (internal)
1007: "invalid data",
1008: "policy violation",
1009: "message too big",
1010: "extension required",
1011: "unexpected error",
# 1015: TLS failure (internal)
}
Frame = collections.namedtuple('Frame', ('fin', 'opcode', 'data'))
Frame.__doc__ = """WebSocket frame.
* ``fin`` is the FIN bit
* ``opcode`` is the opcode
* ``data`` is the payload data
Only these three fields are needed by higher level code. The MASK bit, payload
length and masking-key are handled on the fly by :func:`read_frame` and
:func:`write_frame`.
"""
# Close codes that are allowed in a close frame.
# Using a list optimizes `code in EXTERNAL_CLOSE_CODES`.
EXTERNAL_CLOSE_CODES = [1000, 1001, 1002, 1003, 1007, 1008, 1009, 1010, 1011]
# Consider converting to a dataclass when dropping support for Python < 3.7.
@asyncio.coroutine
def read_frame(reader, mask, *, max_size=None):
"""
Read a WebSocket frame and return a :class:`Frame` object.
``reader`` is a coroutine taking an integer argument and reading exactly
this number of bytes, unless the end of file is reached.
class Frame(NamedTuple):
"""
WebSocket frame.
``mask`` is a :class:`bool` telling whether the frame should be masked
i.e. whether the read happens on the server side.
:param bool fin: FIN bit
:param bool rsv1: RSV1 bit
:param bool rsv2: RSV2 bit
:param bool rsv3: RSV3 bit
:param int opcode: opcode
:param bytes data: payload data
If ``max_size`` is set and the payload exceeds this size in bytes,
:exc:`~websockets.exceptions.PayloadTooBig` is raised.
Only these fields are needed. The MASK bit, payload length and masking-key
are handled on the fly by :meth:`read` and :meth:`write`.
This function validates the frame before returning it and raises
:exc:`~websockets.exceptions.WebSocketProtocolError` if it contains
incorrect values.
"""
fin: bool
opcode: int
data: bytes
rsv1: bool = False
rsv2: bool = False
rsv3: bool = False
@classmethod
async def read(
cls,
reader: Callable[[int], Awaitable[bytes]],
*,
mask: bool,
max_size: Optional[int] = None,
extensions: Optional[Sequence["websockets.extensions.base.Extension"]] = None,
) -> "Frame":
"""
Read a WebSocket frame.
:param reader: coroutine that reads exactly the requested number of
bytes, unless the end of file is reached
:param mask: whether the frame should be masked i.e. whether the read
happens on the server side
:param max_size: maximum payload size in bytes
:param extensions: list of classes with a ``decode()`` method that
transforms the frame and returns a new frame; extensions are applied
in reverse order
:raises ~websockets.exceptions.PayloadTooBig: if the frame exceeds
``max_size``
:raises ~websockets.exceptions.ProtocolError: if the frame
contains incorrect values
"""
# Read the header.
data = await reader(2)
head1, head2 = struct.unpack("!BB", data)
# While not Pythonic, this is marginally faster than calling bool().
fin = True if head1 & 0b10000000 else False
rsv1 = True if head1 & 0b01000000 else False
rsv2 = True if head1 & 0b00100000 else False
rsv3 = True if head1 & 0b00010000 else False
opcode = head1 & 0b00001111
if (True if head2 & 0b10000000 else False) != mask:
raise ProtocolError("incorrect masking")
length = head2 & 0b01111111
if length == 126:
data = await reader(2)
(length,) = struct.unpack("!H", data)
elif length == 127:
data = await reader(8)
(length,) = struct.unpack("!Q", data)
if max_size is not None and length > max_size:
raise PayloadTooBig(
f"payload length exceeds size limit ({length} > {max_size} bytes)"
)
if mask:
mask_bits = await reader(4)
# Read the data.
data = await reader(length)
if mask:
data = apply_mask(data, mask_bits)
frame = cls(fin, opcode, data, rsv1, rsv2, rsv3)
if extensions is None:
extensions = []
for extension in reversed(extensions):
frame = extension.decode(frame, max_size=max_size)
frame.check()
return frame
def write(
frame,
write: Callable[[bytes], Any],
*,
mask: bool,
extensions: Optional[Sequence["websockets.extensions.base.Extension"]] = None,
) -> None:
"""
Write a WebSocket frame.
:param frame: frame to write
:param write: function that writes bytes
:param mask: whether the frame should be masked i.e. whether the write
happens on the client side
:param extensions: list of classes with an ``encode()`` method that
transforms the frame and returns a new frame; extensions are applied
in order
:raises ~websockets.exceptions.ProtocolError: if the frame
contains incorrect values
"""
# The first parameter is called `frame` rather than `self`,
# but it's the instance of the class to which this method is bound.
frame.check()
if extensions is None:
extensions = []
for extension in extensions:
frame = extension.encode(frame)
output = io.BytesIO()
# Prepare the header.
head1 = (
(0b10000000 if frame.fin else 0)
| (0b01000000 if frame.rsv1 else 0)
| (0b00100000 if frame.rsv2 else 0)
| (0b00010000 if frame.rsv3 else 0)
| frame.opcode
)
head2 = 0b10000000 if mask else 0
length = len(frame.data)
if length < 126:
output.write(struct.pack("!BB", head1, head2 | length))
elif length < 65536:
output.write(struct.pack("!BBH", head1, head2 | 126, length))
else:
output.write(struct.pack("!BBQ", head1, head2 | 127, length))
if mask:
mask_bits = struct.pack("!I", random.getrandbits(32))
output.write(mask_bits)
# Prepare the data.
if mask:
data = apply_mask(frame.data, mask_bits)
else:
data = frame.data
output.write(data)
# Send the frame.
# The frame is written in a single call to write in order to prevent
# TCP fragmentation. See #68 for details. This also makes it safe to
# send frames concurrently from multiple coroutines.
write(output.getvalue())
def check(frame) -> None:
"""
Check that reserved bits and opcode have acceptable values.
:raises ~websockets.exceptions.ProtocolError: if a reserved
bit or the opcode is invalid
"""
# The first parameter is called `frame` rather than `self`,
# but it's the instance of the class to which this method is bound.
if frame.rsv1 or frame.rsv2 or frame.rsv3:
raise ProtocolError("reserved bits must be 0")
if frame.opcode in DATA_OPCODES:
return
elif frame.opcode in CTRL_OPCODES:
if len(frame.data) > 125:
raise ProtocolError("control frame too long")
if not frame.fin:
raise ProtocolError("fragmented control frame")
else:
raise ProtocolError(f"invalid opcode: {frame.opcode}")
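# Worked example of the header layout that read() and write() handle: a
# final (FIN=1), unmasked text frame carrying b"Hello" starts with byte
# 0x81 (the FIN bit ORed with OP_TEXT) and a 7-bit payload length of 5.
import struct
head1 = 0b10000000 | 0x01           # FIN set, opcode OP_TEXT
head2 = 0b00000000 | len(b"Hello")  # MASK clear, length fits in 7 bits
assert struct.pack("!BB", head1, head2) + b"Hello" == b"\x81\x05Hello"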
def prepare_data(data: Data) -> Tuple[int, bytes]:
# Read the header
data = yield from reader(2)
head1, head2 = struct.unpack('!BB', data)
fin = bool(head1 & 0b10000000)
if head1 & 0b01110000:
raise WebSocketProtocolError("Reserved bits must be 0")
opcode = head1 & 0b00001111
if bool(head2 & 0b10000000) != mask:
raise WebSocketProtocolError("Incorrect masking")
length = head2 & 0b01111111
if length == 126:
data = yield from reader(2)
length, = struct.unpack('!H', data)
elif length == 127:
data = yield from reader(8)
length, = struct.unpack('!Q', data)
if max_size is not None and length > max_size:
raise PayloadTooBig("Payload exceeds limit "
"({} > {} bytes)".format(length, max_size))
if mask:
mask_bits = yield from reader(4)
# Read the data
data = yield from reader(length)
if mask:
data = apply_mask(data, mask_bits)
frame = Frame(fin, opcode, data)
check_frame(frame)
return frame
def write_frame(frame, writer, mask):
"""
Convert a string or byte-like object to an opcode and a bytes-like object.
Write a WebSocket frame.
This function is designed for data frames.
``frame`` is the :class:`Frame` object to write.
If ``data`` is a :class:`str`, return ``OP_TEXT`` and a :class:`bytes`
object encoding ``data`` in UTF-8.
``writer`` is a function accepting bytes.
If ``data`` is a bytes-like object, return ``OP_BINARY`` and a bytes-like
object.
``mask`` is a :class:`bool` telling whether the frame should be masked
i.e. whether the write happens on the client side.
:raises TypeError: if ``data`` doesn't have a supported type
This function validates the frame before sending it and raises
:exc:`~websockets.exceptions.WebSocketProtocolError` if it contains
incorrect values.
"""
if isinstance(data, str):
return OP_TEXT, data.encode("utf-8")
elif isinstance(data, (bytes, bytearray)):
return OP_BINARY, data
elif isinstance(data, memoryview):
if data.c_contiguous:
return OP_BINARY, data
else:
return OP_BINARY, data.tobytes()
check_frame(frame)
output = io.BytesIO()
# Prepare the header
head1 = 0b10000000 if frame.fin else 0
head1 |= frame.opcode
head2 = 0b10000000 if mask else 0
length = len(frame.data)
if length < 0x7e:
output.write(struct.pack('!BB', head1, head2 | length))
elif length < 0x10000:
output.write(struct.pack('!BBH', head1, head2 | 126, length))
else:
raise TypeError("data must be bytes-like or str")
def encode_data(data: Data) -> bytes:
"""
Convert a string or byte-like object to bytes.
This function is designed for ping and pong frames.
output.write(struct.pack('!BBQ', head1, head2 | 127, length))
if mask:
mask_bits = struct.pack('!I', random.getrandbits(32))
output.write(mask_bits)
# Prepare the data
if mask:
data = apply_mask(frame.data, mask_bits)
else:
data = frame.data
output.write(data)
If ``data`` is a :class:`str`, return a :class:`bytes` object encoding
``data`` in UTF-8.
# Send the frame
writer(output.getvalue())
If ``data`` is a bytes-like object, return a :class:`bytes` object.
:raises TypeError: if ``data`` doesn't have a supported type
def check_frame(frame):
"""
Raise :exc:`~websockets.exceptions.WebSocketProtocolError` if the frame
contains incorrect values.
"""
if isinstance(data, str):
return data.encode("utf-8")
elif isinstance(data, (bytes, bytearray)):
return bytes(data)
elif isinstance(data, memoryview):
return data.tobytes()
if frame.opcode in (OP_CONT, OP_TEXT, OP_BINARY):
return
elif frame.opcode in (OP_CLOSE, OP_PING, OP_PONG):
if len(frame.data) > 125:
raise WebSocketProtocolError("Control frame too long")
if not frame.fin:
raise WebSocketProtocolError("Fragmented control frame")
else:
raise TypeError("data must be bytes-like or str")
raise WebSocketProtocolError("Invalid opcode")
def parse_close(data: bytes) -> Tuple[int, str]:
def parse_close(data):
"""
Parse the payload from a close frame.
Parse the data in a close frame.
Return ``(code, reason)``.
Return ``(code, reason)`` when ``code`` is an :class:`int` and ``reason``
a :class:`str`.
:raises ~websockets.exceptions.ProtocolError: if data is ill-formed
:raises UnicodeDecodeError: if the reason isn't valid UTF-8
Raise :exc:`~websockets.exceptions.WebSocketProtocolError` or
:exc:`UnicodeDecodeError` if the data is invalid.
"""
length = len(data)
if length >= 2:
(code,) = struct.unpack("!H", data[:2])
check_close(code)
reason = data[2:].decode("utf-8")
return code, reason
elif length == 0:
return 1005, ""
if length == 0:
return 1005, ''
elif length == 1:
raise WebSocketProtocolError("Close frame too short")
else:
assert length == 1
raise ProtocolError("close frame too short")
code, = struct.unpack('!H', data[:2])
if not (code in CLOSE_CODES or 3000 <= code < 5000):
raise WebSocketProtocolError("Invalid status code")
reason = data[2:].decode('utf-8')
return code, reason
def serialize_close(code: int, reason: str) -> bytes:
def serialize_close(code, reason):
"""
Serialize the payload for a close frame.
Serialize the data for a close frame.
This is the reverse of :func:`parse_close`.
"""
check_close(code)
return struct.pack("!H", code) + reason.encode("utf-8")
def check_close(code: int) -> None:
"""
Check that the close code has an acceptable value for a close frame.
:raises ~websockets.exceptions.ProtocolError: if the close code
is invalid
"""
if not (code in EXTERNAL_CLOSE_CODES or 3000 <= code < 5000):
raise ProtocolError("invalid status code")
# at the bottom to allow circular import, because Extension depends on Frame
import websockets.extensions.base # isort:skip # noqa
return struct.pack('!H', code) + reason.encode('utf-8')
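# Round-trip sketch for the close-frame helpers (assuming this module is
# importable as websockets.framing); close code 1000 is 0x03E8 big-endian:
from websockets.framing import parse_close, serialize_close
payload = serialize_close(1000, "OK")
assert payload == b"\x03\xe8OK"
assert parse_close(payload) == (1000, "OK")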

@ -1,10 +1,19 @@
"""
:mod:`websockets.handshake` provides helpers for the WebSocket handshake.
See `section 4 of RFC 6455`_.
The :mod:`websockets.handshake` module deals with the WebSocket opening
handshake according to `section 4 of RFC 6455`_.
.. _section 4 of RFC 6455: http://tools.ietf.org/html/rfc6455#section-4
It provides functions to implement the handshake with any existing HTTP
library. You must pass to these functions:
- A ``set_header`` function accepting a header name and a header value,
- A ``get_header`` function accepting a header name and returning the header
value.
The inputs and outputs of ``get_header`` and ``set_header`` are :class:`str`
objects containing only ASCII characters.
Some checks cannot be performed because they depend too much on the
context; instead, they're documented below.
@ -26,160 +35,104 @@ To open a connection, a client must:
"""
import base64
import binascii
import hashlib
import random
from typing import List
from .exceptions import InvalidHeader, InvalidHeaderValue, InvalidUpgrade
from .headers import ConnectionOption, UpgradeProtocol, parse_connection, parse_upgrade
from .http import Headers, MultipleValuesError
from .exceptions import InvalidHandshake
__all__ = ["build_request", "check_request", "build_response", "check_response"]
__all__ = [
'build_request', 'check_request',
'build_response', 'check_response',
]
GUID = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11"
def build_request(headers: Headers) -> str:
def build_request(set_header):
"""
Build a handshake request to send to the server.
Update request headers passed in argument.
:param headers: request headers
:returns: ``key`` which must be passed to :func:`check_response`
Return the ``key`` which must be passed to :func:`check_response`.
"""
raw_key = bytes(random.getrandbits(8) for _ in range(16))
key = base64.b64encode(raw_key).decode()
headers["Upgrade"] = "websocket"
headers["Connection"] = "Upgrade"
headers["Sec-WebSocket-Key"] = key
headers["Sec-WebSocket-Version"] = "13"
rand = bytes(random.getrandbits(8) for _ in range(16))
key = base64.b64encode(rand).decode()
set_header('Upgrade', 'WebSocket')
set_header('Connection', 'Upgrade')
set_header('Sec-WebSocket-Key', key)
set_header('Sec-WebSocket-Version', '13')
return key
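# Usage sketch for the Headers-based build_request (assuming these imports
# resolve as websockets.handshake and websockets.http in this tree):
from websockets.handshake import build_request
from websockets.http import Headers
headers = Headers()
key = build_request(headers)
# headers now carries Upgrade, Connection, Sec-WebSocket-Key and
# Sec-WebSocket-Version; ``key`` must be kept for check_response().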
def check_request(headers: Headers) -> str:
def check_request(get_header):
"""
Check a handshake request received from the client.
This function doesn't verify that the request is an HTTP/1.1 or higher GET
request and doesn't perform ``Host`` and ``Origin`` checks. These controls
are usually performed earlier in the HTTP request handling code. They're
the responsibility of the caller.
:param headers: request headers
:returns: ``key`` which must be passed to :func:`build_response`
:raises ~websockets.exceptions.InvalidHandshake: if the handshake request
is invalid; then the server must return 400 Bad Request error
If the handshake is valid, this function returns the ``key`` which must be
passed to :func:`build_response`.
"""
connection: List[ConnectionOption] = sum(
[parse_connection(value) for value in headers.get_all("Connection")], []
)
Otherwise it raises an :exc:`~websockets.exceptions.InvalidHandshake`
exception and the server must return an error like 400 Bad Request.
if not any(value.lower() == "upgrade" for value in connection):
raise InvalidUpgrade("Connection", ", ".join(connection))
upgrade: List[UpgradeProtocol] = sum(
[parse_upgrade(value) for value in headers.get_all("Upgrade")], []
)
# For compatibility with non-strict implementations, ignore case when
# checking the Upgrade header. It's supposed to be 'WebSocket'.
if not (len(upgrade) == 1 and upgrade[0].lower() == "websocket"):
raise InvalidUpgrade("Upgrade", ", ".join(upgrade))
try:
s_w_key = headers["Sec-WebSocket-Key"]
except KeyError:
raise InvalidHeader("Sec-WebSocket-Key")
except MultipleValuesError:
raise InvalidHeader(
"Sec-WebSocket-Key", "more than one Sec-WebSocket-Key header found"
)
try:
raw_key = base64.b64decode(s_w_key.encode(), validate=True)
except binascii.Error:
raise InvalidHeaderValue("Sec-WebSocket-Key", s_w_key)
if len(raw_key) != 16:
raise InvalidHeaderValue("Sec-WebSocket-Key", s_w_key)
This function doesn't verify that the request is an HTTP/1.1 or higher GET
request and doesn't perform Host and Origin checks. These controls are
usually performed earlier in the HTTP request handling code. They're the
responsibility of the caller.
"""
try:
s_w_version = headers["Sec-WebSocket-Version"]
except KeyError:
raise InvalidHeader("Sec-WebSocket-Version")
except MultipleValuesError:
raise InvalidHeader(
"Sec-WebSocket-Version", "more than one Sec-WebSocket-Version header found"
)
if s_w_version != "13":
raise InvalidHeaderValue("Sec-WebSocket-Version", s_w_version)
return s_w_key
def build_response(headers: Headers, key: str) -> None:
assert get_header('Upgrade').lower() == 'websocket'
assert any(
token.strip() == 'upgrade'
for token in get_header('Connection').lower().split(','))
key = get_header('Sec-WebSocket-Key')
assert len(base64.b64decode(key.encode(), validate=True)) == 16
assert get_header('Sec-WebSocket-Version') == '13'
except Exception as exc:
raise InvalidHandshake("Invalid request") from exc
else:
return key
def build_response(set_header, key):
"""
Build a handshake response to send to the client.
Update response headers passed in argument.
:param headers: response headers
:param key: comes from :func:`check_request`
``key`` comes from :func:`check_request`.
"""
headers["Upgrade"] = "websocket"
headers["Connection"] = "Upgrade"
headers["Sec-WebSocket-Accept"] = accept(key)
set_header('Upgrade', 'WebSocket')
set_header('Connection', 'Upgrade')
set_header('Sec-WebSocket-Accept', accept(key))
def check_response(headers: Headers, key: str) -> None:
def check_response(get_header, key):
"""
Check a handshake response received from the server.
``key`` comes from :func:`build_request`.
If the handshake is valid, this function returns ``None``.
Otherwise it raises an :exc:`~websockets.exceptions.InvalidHandshake`
exception.
This function doesn't verify that the response is an HTTP/1.1 or higher
response with a 101 status code. These controls are the responsibility of
the caller.
:param headers: response headers
:param key: comes from :func:`build_request`
:raises ~websockets.exceptions.InvalidHandshake: if the handshake response
is invalid
"""
connection: List[ConnectionOption] = sum(
[parse_connection(value) for value in headers.get_all("Connection")], []
)
if not any(value.lower() == "upgrade" for value in connection):
raise InvalidUpgrade("Connection", " ".join(connection))
upgrade: List[UpgradeProtocol] = sum(
[parse_upgrade(value) for value in headers.get_all("Upgrade")], []
)
# For compatibility with non-strict implementations, ignore case when
# checking the Upgrade header. It's supposed to be 'WebSocket'.
if not (len(upgrade) == 1 and upgrade[0].lower() == "websocket"):
raise InvalidUpgrade("Upgrade", ", ".join(upgrade))
try:
s_w_accept = headers["Sec-WebSocket-Accept"]
except KeyError:
raise InvalidHeader("Sec-WebSocket-Accept")
except MultipleValuesError:
raise InvalidHeader(
"Sec-WebSocket-Accept", "more than one Sec-WebSocket-Accept header found"
)
if s_w_accept != accept(key):
raise InvalidHeaderValue("Sec-WebSocket-Accept", s_w_accept)
assert get_header('Upgrade').lower() == 'websocket'
assert any(
token.strip() == 'upgrade'
for token in get_header('Connection').lower().split(','))
assert get_header('Sec-WebSocket-Accept') == accept(key)
except Exception as exc:
raise InvalidHandshake("Invalid response") from exc
def accept(key: str) -> str:
def accept(key):
sha1 = hashlib.sha1((key + GUID).encode()).digest()
return base64.b64encode(sha1).decode()
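# Worked example of accept() on the sample key from RFC 6455, section 1.3;
# only the standard library is needed:
import base64
import hashlib
GUID = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11"
sha1 = hashlib.sha1(("dGhlIHNhbXBsZSBub25jZQ==" + GUID).encode()).digest()
assert base64.b64encode(sha1).decode() == "s3pPLMBiTxaQ9kYGzzhZRbK+xOo="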

@ -1,515 +0,0 @@
"""
:mod:`websockets.headers` provides parsers and serializers for HTTP headers
used in WebSocket handshake messages.
These APIs cannot be imported from :mod:`websockets`. They must be imported
from :mod:`websockets.headers`.
"""
import base64
import binascii
import re
from typing import Callable, List, NewType, Optional, Sequence, Tuple, TypeVar, cast
from .exceptions import InvalidHeaderFormat, InvalidHeaderValue
from .typing import ExtensionHeader, ExtensionName, ExtensionParameter, Subprotocol
__all__ = [
"parse_connection",
"parse_upgrade",
"parse_extension",
"build_extension",
"parse_subprotocol",
"build_subprotocol",
"build_www_authenticate_basic",
"parse_authorization_basic",
"build_authorization_basic",
]
T = TypeVar("T")
ConnectionOption = NewType("ConnectionOption", str)
UpgradeProtocol = NewType("UpgradeProtocol", str)
# To avoid a dependency on a parsing library, we implement manually the ABNF
# described in https://tools.ietf.org/html/rfc6455#section-9.1 with the
# definitions from https://tools.ietf.org/html/rfc7230#appendix-B.
def peek_ahead(header: str, pos: int) -> Optional[str]:
"""
Return the next character from ``header`` at the given position.
Return ``None`` at the end of ``header``.
We never need to peek more than one character ahead.
"""
return None if pos == len(header) else header[pos]
_OWS_re = re.compile(r"[\t ]*")
def parse_OWS(header: str, pos: int) -> int:
"""
Parse optional whitespace from ``header`` at the given position.
Return the new position.
The whitespace itself isn't returned because it isn't significant.
"""
# There's always a match, possibly empty, whose content doesn't matter.
match = _OWS_re.match(header, pos)
assert match is not None
return match.end()
_token_re = re.compile(r"[-!#$%&\'*+.^_`|~0-9a-zA-Z]+")
def parse_token(header: str, pos: int, header_name: str) -> Tuple[str, int]:
"""
Parse a token from ``header`` at the given position.
Return the token value and the new position.
:raises ~websockets.exceptions.InvalidHeaderFormat: on invalid inputs.
"""
match = _token_re.match(header, pos)
if match is None:
raise InvalidHeaderFormat(header_name, "expected token", header, pos)
return match.group(), match.end()
_quoted_string_re = re.compile(
r'"(?:[\x09\x20-\x21\x23-\x5b\x5d-\x7e]|\\[\x09\x20-\x7e\x80-\xff])*"'
)
_unquote_re = re.compile(r"\\([\x09\x20-\x7e\x80-\xff])")
def parse_quoted_string(header: str, pos: int, header_name: str) -> Tuple[str, int]:
"""
Parse a quoted string from ``header`` at the given position.
Return the unquoted value and the new position.
:raises ~websockets.exceptions.InvalidHeaderFormat: on invalid inputs.
"""
match = _quoted_string_re.match(header, pos)
if match is None:
raise InvalidHeaderFormat(header_name, "expected quoted string", header, pos)
return _unquote_re.sub(r"\1", match.group()[1:-1]), match.end()
_quotable_re = re.compile(r"[\x09\x20-\x7e\x80-\xff]*")
_quote_re = re.compile(r"([\x22\x5c])")
def build_quoted_string(value: str) -> str:
"""
Format ``value`` as a quoted string.
This is the reverse of :func:`parse_quoted_string`.
"""
match = _quotable_re.fullmatch(value)
if match is None:
raise ValueError("invalid characters for quoted-string encoding")
return '"' + _quote_re.sub(r"\\\1", value) + '"'
def parse_list(
parse_item: Callable[[str, int, str], Tuple[T, int]],
header: str,
pos: int,
header_name: str,
) -> List[T]:
"""
Parse a comma-separated list from ``header`` at the given position.
This is appropriate for parsing values with the following grammar:
1#item
``parse_item`` parses one item.
``header`` is assumed not to start or end with whitespace.
(This function is designed for parsing an entire header value and
:func:`~websockets.http.read_headers` strips whitespace from values.)
Return a list of items.
:raises ~websockets.exceptions.InvalidHeaderFormat: on invalid inputs.
"""
# Per https://tools.ietf.org/html/rfc7230#section-7, "a recipient MUST
# parse and ignore a reasonable number of empty list elements"; hence
# while loops that remove extra delimiters.
# Remove extra delimiters before the first item.
while peek_ahead(header, pos) == ",":
pos = parse_OWS(header, pos + 1)
items = []
while True:
# Loop invariant: an item starts at pos in header.
item, pos = parse_item(header, pos, header_name)
items.append(item)
pos = parse_OWS(header, pos)
# We may have reached the end of the header.
if pos == len(header):
break
# There must be a delimiter after each element except the last one.
if peek_ahead(header, pos) == ",":
pos = parse_OWS(header, pos + 1)
else:
raise InvalidHeaderFormat(header_name, "expected comma", header, pos)
# Remove extra delimiters before the next item.
while peek_ahead(header, pos) == ",":
pos = parse_OWS(header, pos + 1)
# We may have reached the end of the header.
if pos == len(header):
break
# Since we only advance in the header by one character with peek_ahead()
# or with the end position of a regex match, we can't overshoot the end.
assert pos == len(header)
return items
def parse_connection_option(
header: str, pos: int, header_name: str
) -> Tuple[ConnectionOption, int]:
"""
Parse a Connection option from ``header`` at the given position.
Return the protocol value and the new position.
:raises ~websockets.exceptions.InvalidHeaderFormat: on invalid inputs.
"""
item, pos = parse_token(header, pos, header_name)
return cast(ConnectionOption, item), pos
def parse_connection(header: str) -> List[ConnectionOption]:
"""
Parse a ``Connection`` header.
Return a list of HTTP connection options.
:param header: value of the ``Connection`` header
:raises ~websockets.exceptions.InvalidHeaderFormat: on invalid inputs.
"""
return parse_list(parse_connection_option, header, 0, "Connection")
_protocol_re = re.compile(
r"[-!#$%&\'*+.^_`|~0-9a-zA-Z]+(?:/[-!#$%&\'*+.^_`|~0-9a-zA-Z]+)?"
)
def parse_upgrade_protocol(
header: str, pos: int, header_name: str
) -> Tuple[UpgradeProtocol, int]:
"""
Parse an Upgrade protocol from ``header`` at the given position.
Return the protocol value and the new position.
:raises ~websockets.exceptions.InvalidHeaderFormat: on invalid inputs.
"""
match = _protocol_re.match(header, pos)
if match is None:
raise InvalidHeaderFormat(header_name, "expected protocol", header, pos)
return cast(UpgradeProtocol, match.group()), match.end()
def parse_upgrade(header: str) -> List[UpgradeProtocol]:
"""
Parse an ``Upgrade`` header.
Return a list of HTTP protocols.
:param header: value of the ``Upgrade`` header
:raises ~websockets.exceptions.InvalidHeaderFormat: on invalid inputs.
"""
return parse_list(parse_upgrade_protocol, header, 0, "Upgrade")
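# Parsing sketch for typical Connection and Upgrade header values
# (assuming this module is importable as websockets.headers):
from websockets.headers import parse_connection, parse_upgrade
assert parse_connection("keep-alive, Upgrade") == ["keep-alive", "Upgrade"]
assert parse_upgrade("websocket") == ["websocket"]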
def parse_extension_item_param(
header: str, pos: int, header_name: str
) -> Tuple[ExtensionParameter, int]:
"""
Parse a single extension parameter from ``header`` at the given position.
Return a ``(name, value)`` pair and the new position.
:raises ~websockets.exceptions.InvalidHeaderFormat: on invalid inputs.
"""
# Extract parameter name.
name, pos = parse_token(header, pos, header_name)
pos = parse_OWS(header, pos)
# Extract parameter value, if there is one.
value: Optional[str] = None
if peek_ahead(header, pos) == "=":
pos = parse_OWS(header, pos + 1)
if peek_ahead(header, pos) == '"':
pos_before = pos # for proper error reporting below
value, pos = parse_quoted_string(header, pos, header_name)
# https://tools.ietf.org/html/rfc6455#section-9.1 says: the value
# after quoted-string unescaping MUST conform to the 'token' ABNF.
if _token_re.fullmatch(value) is None:
raise InvalidHeaderFormat(
header_name, "invalid quoted header content", header, pos_before
)
else:
value, pos = parse_token(header, pos, header_name)
pos = parse_OWS(header, pos)
return (name, value), pos
def parse_extension_item(
header: str, pos: int, header_name: str
) -> Tuple[ExtensionHeader, int]:
"""
Parse an extension definition from ``header`` at the given position.
Return an ``(extension name, parameters)`` pair, where ``parameters`` is a
list of ``(name, value)`` pairs, and the new position.
:raises ~websockets.exceptions.InvalidHeaderFormat: on invalid inputs.
"""
# Extract extension name.
name, pos = parse_token(header, pos, header_name)
pos = parse_OWS(header, pos)
# Extract all parameters.
parameters = []
while peek_ahead(header, pos) == ";":
pos = parse_OWS(header, pos + 1)
parameter, pos = parse_extension_item_param(header, pos, header_name)
parameters.append(parameter)
return (cast(ExtensionName, name), parameters), pos
def parse_extension(header: str) -> List[ExtensionHeader]:
"""
Parse a ``Sec-WebSocket-Extensions`` header.
Return a list of WebSocket extensions and their parameters in this format::
[
(
'extension name',
[
('parameter name', 'parameter value'),
...
]
),
...
]
Parameter values are ``None`` when no value is provided.
:raises ~websockets.exceptions.InvalidHeaderFormat: on invalid inputs.
"""
return parse_list(parse_extension_item, header, 0, "Sec-WebSocket-Extensions")
parse_extension_list = parse_extension # alias for backwards compatibility
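# Parsing sketch for a typical Sec-WebSocket-Extensions offer (assuming
# this module is importable as websockets.headers):
from websockets.headers import parse_extension
offer = "permessage-deflate; client_max_window_bits"
assert parse_extension(offer) == [
    ("permessage-deflate", [("client_max_window_bits", None)])
]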
def build_extension_item(
name: ExtensionName, parameters: List[ExtensionParameter]
) -> str:
"""
Build an extension definition.
This is the reverse of :func:`parse_extension_item`.
"""
return "; ".join(
[cast(str, name)]
+ [
# Quoted strings aren't necessary because values are always tokens.
name if value is None else f"{name}={value}"
for name, value in parameters
]
)
def build_extension(extensions: Sequence[ExtensionHeader]) -> str:
"""
Build a ``Sec-WebSocket-Extensions`` header.
This is the reverse of :func:`parse_extension`.
"""
return ", ".join(
build_extension_item(name, parameters) for name, parameters in extensions
)
build_extension_list = build_extension # alias for backwards compatibility
def parse_subprotocol_item(
header: str, pos: int, header_name: str
) -> Tuple[Subprotocol, int]:
"""
Parse a subprotocol from ``header`` at the given position.
Return the subprotocol value and the new position.
:raises ~websockets.exceptions.InvalidHeaderFormat: on invalid inputs.
"""
item, pos = parse_token(header, pos, header_name)
return cast(Subprotocol, item), pos
def parse_subprotocol(header: str) -> List[Subprotocol]:
"""
Parse a ``Sec-WebSocket-Protocol`` header.
Return a list of WebSocket subprotocols.
:raises ~websockets.exceptions.InvalidHeaderFormat: on invalid inputs.
"""
return parse_list(parse_subprotocol_item, header, 0, "Sec-WebSocket-Protocol")
parse_subprotocol_list = parse_subprotocol # alias for backwards compatibility
def build_subprotocol(protocols: Sequence[Subprotocol]) -> str:
"""
Build a ``Sec-WebSocket-Protocol`` header.
This is the reverse of :func:`parse_subprotocol`.
"""
return ", ".join(protocols)
build_subprotocol_list = build_subprotocol # alias for backwards compatibility
def build_www_authenticate_basic(realm: str) -> str:
"""
Build a ``WWW-Authenticate`` header for HTTP Basic Auth.
:param realm: authentication realm
"""
# https://tools.ietf.org/html/rfc7617#section-2
realm = build_quoted_string(realm)
charset = build_quoted_string("UTF-8")
return f"Basic realm={realm}, charset={charset}"
_token68_re = re.compile(r"[A-Za-z0-9-._~+/]+=*")
def parse_token68(header: str, pos: int, header_name: str) -> Tuple[str, int]:
"""
Parse a token68 from ``header`` at the given position.
Return the token value and the new position.
:raises ~websockets.exceptions.InvalidHeaderFormat: on invalid inputs.
"""
match = _token68_re.match(header, pos)
if match is None:
raise InvalidHeaderFormat(header_name, "expected token68", header, pos)
return match.group(), match.end()
def parse_end(header: str, pos: int, header_name: str) -> None:
"""
Check that parsing reached the end of header.
"""
if pos < len(header):
raise InvalidHeaderFormat(header_name, "trailing data", header, pos)
def parse_authorization_basic(header: str) -> Tuple[str, str]:
"""
Parse an ``Authorization`` header for HTTP Basic Auth.
Return a ``(username, password)`` tuple.
:param header: value of the ``Authorization`` header
:raises InvalidHeaderFormat: on invalid inputs
:raises InvalidHeaderValue: on unsupported inputs
"""
# https://tools.ietf.org/html/rfc7235#section-2.1
# https://tools.ietf.org/html/rfc7617#section-2
scheme, pos = parse_token(header, 0, "Authorization")
if scheme.lower() != "basic":
raise InvalidHeaderValue("Authorization", f"unsupported scheme: {scheme}")
if peek_ahead(header, pos) != " ":
raise InvalidHeaderFormat(
"Authorization", "expected space after scheme", header, pos
)
pos += 1
basic_credentials, pos = parse_token68(header, pos, "Authorization")
parse_end(header, pos, "Authorization")
try:
user_pass = base64.b64decode(basic_credentials.encode()).decode()
except binascii.Error:
raise InvalidHeaderValue(
"Authorization", "expected base64-encoded credentials"
) from None
try:
username, password = user_pass.split(":", 1)
except ValueError:
raise InvalidHeaderValue(
"Authorization", "expected username:password credentials"
) from None
return username, password
def build_authorization_basic(username: str, password: str) -> str:
"""
Build an ``Authorization`` header for HTTP Basic Auth.
This is the reverse of :func:`parse_authorization_basic`.
"""
# https://tools.ietf.org/html/rfc7617#section-2
assert ":" not in username
user_pass = f"{username}:{password}"
basic_credentials = base64.b64encode(user_pass.encode()).decode()
return "Basic " + basic_credentials

@ -1,57 +1,36 @@
"""
:mod:`websockets.http` module provides basic HTTP/1.1 support. It is merely
adequate for WebSocket handshake messages.
The :mod:`websockets.http` module provides HTTP parsing functions. They're
merely adequate for the WebSocket handshake messages.
These APIs cannot be imported from :mod:`websockets`. They must be imported
from :mod:`websockets.http`.
These functions cannot be imported from :mod:`websockets`; they must be
imported from :mod:`websockets.http`.
"""
import asyncio
import http.client
import re
import sys
from typing import (
Any,
Dict,
Iterable,
Iterator,
List,
Mapping,
MutableMapping,
Tuple,
Union,
)
from .version import version as websockets_version
__all__ = [
"read_request",
"read_response",
"Headers",
"MultipleValuesError",
"USER_AGENT",
]
__all__ = ['read_request', 'read_response', 'USER_AGENT']
MAX_HEADERS = 256
MAX_LINE = 4096
USER_AGENT = f"Python/{sys.version[:3]} websockets/{websockets_version}"
def d(value: bytes) -> str:
"""
Decode a bytestring for interpolating into an error message.
"""
return value.decode(errors="backslashreplace")
USER_AGENT = ' '.join((
'Python/{}'.format(sys.version[:3]),
'websockets/{}'.format(websockets_version),
))
# See https://tools.ietf.org/html/rfc7230#appendix-B.
# Regex for validating header names.
_token_re = re.compile(rb"[-!#$%&\'*+.^_`|~0-9a-zA-Z]+")
_token_re = re.compile(rb'^[-!#$%&\'*+.^_`|~0-9a-zA-Z]+$')
# Regex for validating header values.
@ -64,26 +43,28 @@ _token_re = re.compile(rb"[-!#$%&\'*+.^_`|~0-9a-zA-Z]+")
# See also https://www.rfc-editor.org/errata_search.php?rfc=7230&eid=4189
_value_re = re.compile(rb"[\x09\x20-\x7e\x80-\xff]*")
_value_re = re.compile(rb'^[\x09\x20-\x7e\x80-\xff]*$')
async def read_request(stream: asyncio.StreamReader) -> Tuple[str, "Headers"]:
@asyncio.coroutine
def read_request(stream):
"""
Read an HTTP/1.1 GET request and return ``(path, headers)``.
Read an HTTP/1.1 GET request from ``stream``.
``stream`` is an :class:`~asyncio.StreamReader`.
Return ``(path, headers)`` where ``path`` is a :class:`str` and
``headers`` is a list of ``(name, value)`` tuples.
``path`` isn't URL-decoded or validated in any way.
``path`` and ``headers`` are expected to contain only ASCII characters.
Other characters are represented with surrogate escapes.
Non-ASCII characters are represented with surrogate escapes.
:func:`read_request` doesn't attempt to read the request body because
WebSocket handshake requests don't have one. If the request contains a
body, it may be read from ``stream`` after this coroutine returns.
Raise an exception if the request isn't well formatted.
:param stream: input to read the request from
:raises EOFError: if the connection is closed without a full HTTP request
:raises SecurityError: if the request exceeds a security limit
:raises ValueError: if the request isn't well formatted
Don't attempt to read the request body because WebSocket handshake
requests don't have one. If the request contains a body, it may be
read from ``stream`` after this coroutine returns.
"""
# https://tools.ietf.org/html/rfc7230#section-3.1.1
@ -92,42 +73,41 @@ async def read_request(stream: asyncio.StreamReader) -> Tuple[str, "Headers"]:
# version and because path isn't checked. Since WebSocket software tends
# to implement HTTP/1.1 strictly, there's little need for lenient parsing.
try:
request_line = await read_line(stream)
except EOFError as exc:
raise EOFError("connection closed while reading HTTP request line") from exc
# Given the implementation of read_line(), request_line ends with CRLF.
request_line = yield from read_line(stream)
try:
method, raw_path, version = request_line.split(b" ", 2)
except ValueError: # not enough values to unpack (expected 3, got 1-2)
raise ValueError(f"invalid HTTP request line: {d(request_line)}") from None
# This may raise "ValueError: not enough values to unpack"
method, path, version = request_line[:-2].split(b' ', 2)
if method != b"GET":
raise ValueError(f"unsupported HTTP method: {d(method)}")
if version != b"HTTP/1.1":
raise ValueError(f"unsupported HTTP version: {d(version)}")
path = raw_path.decode("ascii", "surrogateescape")
if method != b'GET':
raise ValueError("Unsupported HTTP method: %r" % method)
if version != b'HTTP/1.1':
raise ValueError("Unsupported HTTP version: %r" % version)
headers = await read_headers(stream)
path = path.decode('ascii', 'surrogateescape')
headers = yield from read_headers(stream)
return path, headers
async def read_response(stream: asyncio.StreamReader) -> Tuple[int, str, "Headers"]:
@asyncio.coroutine
def read_response(stream):
"""
Read an HTTP/1.1 response and return ``(status_code, reason, headers)``.
Read an HTTP/1.1 response from ``stream``.
``stream`` is an :class:`~asyncio.StreamReader`.
``reason`` and ``headers`` are expected to contain only ASCII characters.
Other characters are represented with surrogate escapes.
Return ``(status_code, headers)`` where ``status_code`` is an :class:`int`
and ``headers`` is a list of ``(name, value)`` tuples.
Non-ASCII characters are represented with surrogate escapes.
:func:`read_request` doesn't attempt to read the response body because
WebSocket handshake responses don't have one. If the response contains a
body, it may be read from ``stream`` after this coroutine returns.
Raise an exception if the response isn't well formatted.
:param stream: input to read the response from
:raises EOFError: if the connection is closed without a full HTTP response
:raises SecurityError: if the response exceeds a security limit
:raises ValueError: if the response isn't well formatted
Don't attempt to read the response body, because WebSocket handshake
responses don't have one. If the response contains a body, it may be
read from ``stream`` after this coroutine returns.
"""
# https://tools.ietf.org/html/rfc7230#section-3.1.2
@ -135,37 +115,36 @@ async def read_response(stream: asyncio.StreamReader) -> Tuple[int, str, "Header
# As in read_request, parsing is simple because a fixed value is expected
# for version, status_code is a 3-digit number, and reason can be ignored.
try:
status_line = await read_line(stream)
except EOFError as exc:
raise EOFError("connection closed while reading HTTP status line") from exc
try:
version, raw_status_code, raw_reason = status_line.split(b" ", 2)
except ValueError: # not enough values to unpack (expected 3, got 1-2)
raise ValueError(f"invalid HTTP status line: {d(status_line)}") from None
if version != b"HTTP/1.1":
raise ValueError(f"unsupported HTTP version: {d(version)}")
try:
status_code = int(raw_status_code)
except ValueError: # invalid literal for int() with base 10
raise ValueError(f"invalid HTTP status code: {d(raw_status_code)}") from None
# Given the implementation of read_line(), status_line ends with CRLF.
status_line = yield from read_line(stream)
# This may raise "ValueError: not enough values to unpack"
version, status_code, reason = status_line[:-2].split(b' ', 2)
if version != b'HTTP/1.1':
raise ValueError("Unsupported HTTP version: %r" % version)
# This may raise "ValueError: invalid literal for int() with base 10"
status_code = int(status_code)
if not 100 <= status_code < 1000:
raise ValueError(f"unsupported HTTP status code: {d(raw_status_code)}")
if not _value_re.fullmatch(raw_reason):
raise ValueError(f"invalid HTTP reason phrase: {d(raw_reason)}")
reason = raw_reason.decode()
raise ValueError("Unsupported HTTP status_code code: %d" % status_code)
if not _value_re.match(reason):
raise ValueError("Invalid HTTP reason phrase: %r" % reason)
headers = await read_headers(stream)
headers = yield from read_headers(stream)
return status_code, reason, headers
return status_code, headers
async def read_headers(stream: asyncio.StreamReader) -> "Headers":
@asyncio.coroutine
def read_headers(stream):
"""
Read HTTP headers from ``stream``.
``stream`` is an :class:`~asyncio.StreamReader`.
Return ``(start_line, headers)`` where ``start_line`` is :class:`bytes`
and ``headers`` is a list of ``(name, value)`` tuples.
Non-ASCII characters are represented with surrogate escapes.
"""
@ -173,188 +152,57 @@ async def read_headers(stream: asyncio.StreamReader) -> "Headers":
# We don't attempt to support obsolete line folding.
headers = Headers()
for _ in range(MAX_HEADERS + 1):
try:
line = await read_line(stream)
except EOFError as exc:
raise EOFError("connection closed while reading HTTP headers") from exc
if line == b"":
headers = []
for _ in range(MAX_HEADERS):
line = yield from read_line(stream)
if line == b'\r\n':
break
try:
raw_name, raw_value = line.split(b":", 1)
except ValueError: # not enough values to unpack (expected 2, got 1)
raise ValueError(f"invalid HTTP header line: {d(line)}") from None
if not _token_re.fullmatch(raw_name):
raise ValueError(f"invalid HTTP header name: {d(raw_name)}")
raw_value = raw_value.strip(b" \t")
if not _value_re.fullmatch(raw_value):
raise ValueError(f"invalid HTTP header value: {d(raw_value)}")
# This may raise "ValueError: not enough values to unpack"
name, value = line[:-2].split(b':', 1)
if not _token_re.match(name):
raise ValueError("Invalid HTTP header name: %r" % name)
value = value.strip(b' \t')
if not _value_re.match(value):
raise ValueError("Invalid HTTP header value: %r" % value)
name = raw_name.decode("ascii") # guaranteed to be ASCII at this point
value = raw_value.decode("ascii", "surrogateescape")
headers[name] = value
headers.append((
name.decode('ascii'), # guaranteed to be ASCII at this point
value.decode('ascii', 'surrogateescape'),
))
else:
raise websockets.exceptions.SecurityError("too many HTTP headers")
raise ValueError("Too many HTTP headers")
return headers
async def read_line(stream: asyncio.StreamReader) -> bytes:
@asyncio.coroutine
def read_line(stream):
"""
Read a single line from ``stream``.
CRLF is stripped from the return value.
``stream`` is an :class:`~asyncio.StreamReader`.
"""
# Security: this is bounded by the StreamReader's limit (default = 32 KiB).
line = await stream.readline()
# Security: this guarantees header values are small (hard-coded = 4 KiB)
# Security: this is bounded by the StreamReader's limit (default = 32kB).
line = yield from stream.readline()
# Security: this guarantees header values are small (hardcoded = 4kB)
if len(line) > MAX_LINE:
raise websockets.exceptions.SecurityError("line too long")
raise ValueError("Line too long")
# Not mandatory but safe - https://tools.ietf.org/html/rfc7230#section-3.5
if not line.endswith(b"\r\n"):
raise EOFError("line without CRLF")
return line[:-2]
if not line.endswith(b'\r\n'):
raise ValueError("Line without CRLF")
return line
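# Usage sketch for the coroutine-based request parser above (assuming this
# module is importable as websockets.http):
import asyncio
from websockets.http import read_request
async def demo():
    reader = asyncio.StreamReader()
    reader.feed_data(b"GET /chat HTTP/1.1\r\nHost: example.com\r\n\r\n")
    reader.feed_eof()
    path, headers = await read_request(reader)
    assert path == "/chat" and headers["Host"] == "example.com"
asyncio.get_event_loop().run_until_complete(demo())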
class MultipleValuesError(LookupError):
"""
Exception raised when :class:`Headers` has more than one value for a key.
def build_headers(raw_headers):
"""
Build a data structure for HTTP headers from a list of name - value pairs.
def __str__(self) -> str:
# Implement the same logic as KeyError_str in Objects/exceptions.c.
if len(self.args) == 1:
return repr(self.args[0])
return super().__str__()
See also https://github.com/aaugustin/websockets/issues/210.
class Headers(MutableMapping[str, str]):
"""
Efficient data structure for manipulating HTTP headers.
A :class:`list` of ``(name, values)`` is inefficient for lookups.
A :class:`dict` doesn't suffice because header names are case-insensitive
and multiple occurrences of headers with the same name are possible.
:class:`Headers` stores HTTP headers in a hybrid data structure to provide
efficient insertions and lookups while preserving the original data.
In order to account for multiple values with minimal hassle,
:class:`Headers` follows this logic:
- When getting a header with ``headers[name]``:
- if there's no value, :exc:`KeyError` is raised;
- if there's exactly one value, it's returned;
- if there's more than one value, :exc:`MultipleValuesError` is raised.
- When setting a header with ``headers[name] = value``, the value is
appended to the list of values for that header.
- When deleting a header with ``del headers[name]``, all values for that
header are removed (this is slow).
Other methods for manipulating headers are consistent with this logic.
As long as no header occurs multiple times, :class:`Headers` behaves like
:class:`dict`, except keys are lower-cased to provide case-insensitivity.
Two methods support manipulating multiple values explicitly:
- :meth:`get_all` returns a list of all values for a header;
- :meth:`raw_items` returns an iterator of ``(name, values)`` pairs.
"""
__slots__ = ["_dict", "_list"]
def __init__(self, *args: Any, **kwargs: str) -> None:
self._dict: Dict[str, List[str]] = {}
self._list: List[Tuple[str, str]] = []
# MutableMapping.update calls __setitem__ for each (name, value) pair.
self.update(*args, **kwargs)
def __str__(self) -> str:
return "".join(f"{key}: {value}\r\n" for key, value in self._list) + "\r\n"
def __repr__(self) -> str:
return f"{self.__class__.__name__}({self._list!r})"
def copy(self) -> "Headers":
copy = self.__class__()
copy._dict = self._dict.copy()
copy._list = self._list.copy()
return copy
# Collection methods
def __contains__(self, key: object) -> bool:
return isinstance(key, str) and key.lower() in self._dict
def __iter__(self) -> Iterator[str]:
return iter(self._dict)
def __len__(self) -> int:
return len(self._dict)
# MutableMapping methods
def __getitem__(self, key: str) -> str:
value = self._dict[key.lower()]
if len(value) == 1:
return value[0]
else:
raise MultipleValuesError(key)
def __setitem__(self, key: str, value: str) -> None:
self._dict.setdefault(key.lower(), []).append(value)
self._list.append((key, value))
def __delitem__(self, key: str) -> None:
key_lower = key.lower()
self._dict.__delitem__(key_lower)
# This is inefficient. Fortunately deleting HTTP headers is uncommon.
self._list = [(k, v) for k, v in self._list if k.lower() != key_lower]
def __eq__(self, other: Any) -> bool:
if not isinstance(other, Headers):
return NotImplemented
return self._list == other._list
def clear(self) -> None:
"""
Remove all headers.
"""
self._dict = {}
self._list = []
# Methods for handling multiple values
def get_all(self, key: str) -> List[str]:
"""
Return the (possibly empty) list of all values for a header.
:param key: header name
"""
return self._dict.get(key.lower(), [])
def raw_items(self) -> Iterator[Tuple[str, str]]:
"""
Return an iterator of all values as ``(name, value)`` pairs.
"""
return iter(self._list)
HeadersLike = Union[Headers, Mapping[str, str], Iterable[Tuple[str, str]]]
# at the bottom to allow circular import, because AbortHandshake depends on HeadersLike
import websockets.exceptions # isort:skip # noqa
headers = http.client.HTTPMessage()
headers._headers = raw_headers # HACK
return headers
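# Behavior sketch for the Headers class above (assuming this module is
# importable as websockets.http):
from websockets.http import Headers, MultipleValuesError
h = Headers()
h["Server"] = "a"
h["SERVER"] = "b"                  # names are case-insensitive; values append
assert h.get_all("server") == ["a", "b"]
try:
    h["Server"]                    # more than one value for this name
except MultipleValuesError:
    pass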
