Compare commits

...

13 Commits

Author SHA1 Message Date
dirkf
47214e46d8 [compat] Fix old Pythons broken loading of valueless cookie attributes
Cookie string parsing in Py 2.6.9, probably earlier, requires `=`.
Also 3.2, though the CPython code appears to be OK: 3.1 was also wrong.
2023-07-18 10:50:46 +01:00
dirkf
1d8d5a93f7 [test] Fixes for old Pythons 2023-07-18 10:50:46 +01:00
dirkf
1634b1d61e [doc] Warn against setting cookies with --add-header 2023-07-18 10:50:46 +01:00
bashonly
21438a4194 [downloader/external] Fix cookie support 2023-07-18 10:50:46 +01:00
Simon Sawicki
8334ec961b [core] Process header cookies on loading 2023-07-18 10:50:46 +01:00
bashonly
3801d36416 [utils] YoutubeDLCookieJar: Add get_cookie_header and get_cookies_for_url methods 2023-07-18 10:50:46 +01:00
dirkf
b383be9887 [core] Remove Cookie header on redirect to prevent leaks
Adated from yt-dlp/yt-dlp-ghsa-v8mc-9377-rwjj/pull/1/commits/101caac
Thx coletdjnz
2023-07-18 10:50:46 +01:00
dirkf
46fde7caee [core] Update redirect handling from yt-dlp
* Thx coletdjnz: https://github.com/yt-dlp/yt-dlp/pull/7094
* add test that redirected `POST` loses its `Content-Type`
2023-07-18 10:50:46 +01:00
dirkf
648dc5304c [compat] Add Request and HTTPClient compat for redirect
* support `method` parameter of `Request.__init__`  (Py 2 and old Py 3)
* support `getcode` method of compat_http_client.HTTPResponse (Py 2)
2023-07-18 10:50:46 +01:00
dirkf
1720c04dc5 [test] Make skipped tests in test_execution work with Py 2.6 2023-07-18 10:50:46 +01:00
dirkf
d5ef405c5d [core] Align error reporting methods with yt-dlp 2023-07-18 10:50:46 +01:00
dirkf
f47fdb9564 [utils] Add {expected_type} and Iterable support to traverse_obj() 2023-07-18 10:50:46 +01:00
dirkf
b6dff4073d [core] Revert version display from b8a86dc 2023-07-18 10:50:46 +01:00
15 changed files with 1524 additions and 276 deletions

View File

@ -301,7 +301,7 @@ jobs:
if: ${{ matrix.python-version == '2.6' }}
shell: bash
run: |
# see pip for Jython
# Work around deprecation of support for non-SNI clients at PyPI CDN (see https://status.python.org/incidents/hzmjhqsdjqgb)
$PIP -qq show unittest2 || { \
for u in "65/26/32b8464df2a97e6dd1b656ed26b2c194606c16fe163c695a992b36c11cdf/six-1.13.0-py2.py3-none-any.whl" \
"f2/94/3af39d34be01a24a6e65433d19e107099374224905f1e0cc6bbe1fd22a2f/argparse-1.4.0-py2.py3-none-any.whl" \
@ -312,7 +312,7 @@ jobs:
$PIP install ${u##*/}; \
done; }
# make tests use unittest2
for test in ./test/test_*.py; do
for test in ./test/test_*.py ./test/helper.py; do
sed -r -i -e '/^import unittest$/s/test/test2 as unittest/' "$test"
done
#-------- nose --------

View File

@ -9,6 +9,7 @@ import re
import types
import ssl
import sys
import unittest
import youtube_dl.extractor
from youtube_dl import YoutubeDL
@ -17,6 +18,7 @@ from youtube_dl.compat import (
compat_str,
)
from youtube_dl.utils import (
IDENTITY,
preferredencoding,
write_string,
)
@ -72,7 +74,8 @@ class FakeYDL(YoutubeDL):
def to_screen(self, s, skip_eol=None):
print(s)
def trouble(self, s, tb=None):
def trouble(self, *args, **kwargs):
s = args[0] if len(args) > 0 else kwargs.get('message', 'Missing message')
raise Exception(s)
def download(self, x):
@ -297,3 +300,7 @@ def http_server_port(httpd):
else:
sock = httpd.socket
return sock.getsockname()[1]
def expectedFailureIf(cond):
return unittest.expectedFailure if cond else IDENTITY

View File

@ -10,14 +10,30 @@ import unittest
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
import copy
import json
from test.helper import FakeYDL, assertRegexpMatches
from test.helper import (
FakeYDL,
assertRegexpMatches,
try_rm,
)
from youtube_dl import YoutubeDL
from youtube_dl.compat import compat_str, compat_urllib_error
from youtube_dl.compat import (
compat_http_cookiejar_Cookie,
compat_http_cookies_SimpleCookie,
compat_kwargs,
compat_str,
compat_urllib_error,
)
from youtube_dl.extractor import YoutubeIE
from youtube_dl.extractor.common import InfoExtractor
from youtube_dl.postprocessor.common import PostProcessor
from youtube_dl.utils import ExtractorError, match_filter_func
from youtube_dl.utils import (
ExtractorError,
match_filter_func,
traverse_obj,
)
TEST_URL = 'http://localhost/sample.mp4'
@ -29,11 +45,14 @@ class YDL(FakeYDL):
self.msgs = []
def process_info(self, info_dict):
self.downloaded_info_dicts.append(info_dict)
self.downloaded_info_dicts.append(info_dict.copy())
def to_screen(self, msg):
self.msgs.append(msg)
def dl(self, *args, **kwargs):
assert False, 'Downloader must not be invoked for test_YoutubeDL'
def _make_result(formats, **kwargs):
res = {
@ -42,8 +61,9 @@ def _make_result(formats, **kwargs):
'title': 'testttitle',
'extractor': 'testex',
'extractor_key': 'TestEx',
'webpage_url': 'http://example.com/watch?v=shenanigans',
}
res.update(**kwargs)
res.update(**compat_kwargs(kwargs))
return res
@ -930,17 +950,11 @@ class TestYoutubeDL(unittest.TestCase):
# Test case for https://github.com/ytdl-org/youtube-dl/issues/27064
def test_ignoreerrors_for_playlist_with_url_transparent_iterable_entries(self):
class _YDL(YDL):
def __init__(self, *args, **kwargs):
super(_YDL, self).__init__(*args, **kwargs)
def trouble(self, s, tb=None):
pass
ydl = _YDL({
ydl = YDL({
'format': 'extra',
'ignoreerrors': True,
})
ydl.trouble = lambda *_, **__: None
class VideoIE(InfoExtractor):
_VALID_URL = r'video:(?P<id>\d+)'
@ -1017,5 +1031,160 @@ class TestYoutubeDL(unittest.TestCase):
self.assertEqual(out_info['release_date'], '20210930')
class TestYoutubeDLCookies(unittest.TestCase):
@staticmethod
def encode_cookie(cookie):
if not isinstance(cookie, dict):
cookie = vars(cookie)
for name, value in cookie.items():
yield name, compat_str(value)
@classmethod
def comparable_cookies(cls, cookies):
# Work around cookiejar cookies not being unicode strings
return sorted(map(tuple, map(sorted, map(cls.encode_cookie, cookies))))
def assertSameCookies(self, c1, c2, msg=None):
return self.assertEqual(
*map(self.comparable_cookies, (c1, c2)),
msg=msg)
def assertSameCookieStrings(self, c1, c2, msg=None):
return self.assertSameCookies(
*map(lambda c: compat_http_cookies_SimpleCookie(c).values(), (c1, c2)),
msg=msg)
def test_header_cookies(self):
ydl = FakeYDL()
ydl.report_warning = lambda *_, **__: None
def cookie(name, value, version=None, domain='', path='', secure=False, expires=None):
return compat_http_cookiejar_Cookie(
version or 0, name, value, None, False,
domain, bool(domain), bool(domain), path, bool(path),
secure, expires, False, None, None, rest={})
test_url, test_domain = (t % ('yt.dl',) for t in ('https://%s/test', '.%s'))
def test(encoded_cookies, cookies, headers=False, round_trip=None, error_re=None):
def _test():
ydl.cookiejar.clear()
ydl._load_cookies(encoded_cookies, autoscope=headers)
if headers:
ydl._apply_header_cookies(test_url)
data = {'url': test_url}
ydl._calc_headers(data)
self.assertSameCookies(
cookies, ydl.cookiejar,
'Extracted cookiejar.Cookie is not the same')
if not headers:
self.assertSameCookieStrings(
data.get('cookies'), round_trip or encoded_cookies,
msg='Cookie is not the same as round trip')
ydl.__dict__['_YoutubeDL__header_cookies'] = []
try:
_test()
except AssertionError:
raise
except Exception as e:
if not error_re:
raise
assertRegexpMatches(self, e.args[0], error_re.join(('.*',) * 2))
test('test=value; Domain=' + test_domain, [cookie('test', 'value', domain=test_domain)])
test('test=value', [cookie('test', 'value')], error_re='Unscoped cookies are not allowed')
test('cookie1=value1; Domain={0}; Path=/test; cookie2=value2; Domain={0}; Path=/'.format(test_domain), [
cookie('cookie1', 'value1', domain=test_domain, path='/test'),
cookie('cookie2', 'value2', domain=test_domain, path='/')])
cookie_kw = compat_kwargs(
{'domain': test_domain, 'path': '/test', 'secure': True, 'expires': '9999999999', })
test('test=value; Domain={domain}; Path={path}; Secure; Expires={expires}'.format(**cookie_kw), [
cookie('test', 'value', **cookie_kw)])
test('test="value; "; path=/test; domain=' + test_domain, [
cookie('test', 'value; ', domain=test_domain, path='/test')],
round_trip='test="value\\073 "; Domain={0}; Path=/test'.format(test_domain))
test('name=; Domain=' + test_domain, [cookie('name', '', domain=test_domain)],
round_trip='name=""; Domain=' + test_domain)
test('test=value', [cookie('test', 'value', domain=test_domain)], headers=True)
test('cookie1=value; Domain={0}; cookie2=value'.format(test_domain), [],
headers=True, error_re='Invalid syntax')
ydl.report_warning = ydl.report_error
test('test=value', [], headers=True, error_re='Passing cookies as a header is a potential security risk')
def test_infojson_cookies(self):
TEST_FILE = 'test_infojson_cookies.info.json'
TEST_URL = 'https://example.com/example.mp4'
COOKIES = 'a=b; Domain=.example.com; c=d; Domain=.example.com'
COOKIE_HEADER = {'Cookie': 'a=b; c=d'}
ydl = FakeYDL()
ydl.process_info = lambda x: ydl._write_info_json('test', x, TEST_FILE)
def make_info(info_header_cookies=False, fmts_header_cookies=False, cookies_field=False):
fmt = {'url': TEST_URL}
if fmts_header_cookies:
fmt['http_headers'] = COOKIE_HEADER
if cookies_field:
fmt['cookies'] = COOKIES
return _make_result([fmt], http_headers=COOKIE_HEADER if info_header_cookies else None)
def test(initial_info, note):
def failure_msg(why):
return ' when '.join((why, note))
result = {}
result['processed'] = ydl.process_ie_result(initial_info)
self.assertTrue(ydl.cookiejar.get_cookies_for_url(TEST_URL),
msg=failure_msg('No cookies set in cookiejar after initial process'))
ydl.cookiejar.clear()
with open(TEST_FILE) as infojson:
result['loaded'] = ydl.sanitize_info(json.load(infojson), True)
result['final'] = ydl.process_ie_result(result['loaded'].copy(), download=False)
self.assertTrue(ydl.cookiejar.get_cookies_for_url(TEST_URL),
msg=failure_msg('No cookies set in cookiejar after final process'))
ydl.cookiejar.clear()
for key in ('processed', 'loaded', 'final'):
info = result[key]
self.assertIsNone(
traverse_obj(info, ((None, ('formats', 0)), 'http_headers', 'Cookie'), casesense=False, get_all=False),
msg=failure_msg('Cookie header not removed in {0} result'.format(key)))
self.assertSameCookieStrings(
traverse_obj(info, ((None, ('formats', 0)), 'cookies'), get_all=False), COOKIES,
msg=failure_msg('No cookies field found in {0} result'.format(key)))
test({'url': TEST_URL, 'http_headers': COOKIE_HEADER, 'id': '1', 'title': 'x'}, 'no formats field')
test(make_info(info_header_cookies=True), 'info_dict header cokies')
test(make_info(fmts_header_cookies=True), 'format header cookies')
test(make_info(info_header_cookies=True, fmts_header_cookies=True), 'info_dict and format header cookies')
test(make_info(info_header_cookies=True, fmts_header_cookies=True, cookies_field=True), 'all cookies fields')
test(make_info(cookies_field=True), 'cookies format field')
test({'url': TEST_URL, 'cookies': COOKIES, 'id': '1', 'title': 'x'}, 'info_dict cookies field only')
try_rm(TEST_FILE)
def test_add_headers_cookie(self):
def check_for_cookie_header(result):
return traverse_obj(result, ((None, ('formats', 0)), 'http_headers', 'Cookie'), casesense=False, get_all=False)
ydl = FakeYDL({'http_headers': {'Cookie': 'a=b'}})
ydl._apply_header_cookies(_make_result([])['webpage_url']) # Scope to input webpage URL: .example.com
fmt = {'url': 'https://example.com/video.mp4'}
result = ydl.process_ie_result(_make_result([fmt]), download=False)
self.assertIsNone(check_for_cookie_header(result), msg='http_headers cookies in result info_dict')
self.assertEqual(result.get('cookies'), 'a=b; Domain=.example.com', msg='No cookies were set in cookies field')
self.assertIn('a=b', ydl.cookiejar.get_cookie_header(fmt['url']), msg='No cookies were set in cookiejar')
fmt = {'url': 'https://wrong.com/video.mp4'}
result = ydl.process_ie_result(_make_result([fmt]), download=False)
self.assertIsNone(check_for_cookie_header(result), msg='http_headers cookies for wrong domain')
self.assertFalse(result.get('cookies'), msg='Cookies set in cookies field for wrong domain')
self.assertFalse(ydl.cookiejar.get_cookie_header(fmt['url']), msg='Cookies set in cookiejar for wrong domain')
if __name__ == '__main__':
unittest.main()

View File

@ -46,6 +46,20 @@ class TestYoutubeDLCookieJar(unittest.TestCase):
# will be ignored
self.assertFalse(cookiejar._cookies)
def test_get_cookie_header(self):
cookiejar = YoutubeDLCookieJar('./test/testdata/cookies/httponly_cookies.txt')
cookiejar.load(ignore_discard=True, ignore_expires=True)
header = cookiejar.get_cookie_header('https://www.foobar.foobar')
self.assertIn('HTTPONLY_COOKIE', header)
def test_get_cookies_for_url(self):
cookiejar = YoutubeDLCookieJar('./test/testdata/cookies/session_cookies.txt')
cookiejar.load(ignore_discard=True, ignore_expires=True)
cookies = cookiejar.get_cookies_for_url('https://www.foobar.foobar/')
self.assertEqual(len(cookies), 2)
cookies = cookiejar.get_cookies_for_url('https://foobar.foobar/')
self.assertFalse(cookies)
if __name__ == '__main__':
unittest.main()

View File

@ -12,20 +12,65 @@ sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from test.helper import (
FakeLogger,
FakeYDL,
http_server_port,
try_rm,
)
from youtube_dl import YoutubeDL
from youtube_dl.compat import compat_http_server
from youtube_dl.utils import encodeFilename
from youtube_dl.downloader.external import Aria2pFD
from youtube_dl.compat import (
compat_http_cookiejar_Cookie,
compat_http_server,
compat_kwargs,
)
from youtube_dl.utils import (
encodeFilename,
join_nonempty,
)
from youtube_dl.downloader.external import (
Aria2cFD,
Aria2pFD,
AxelFD,
CurlFD,
FFmpegFD,
HttpieFD,
WgetFD,
)
import threading
TEST_DIR = os.path.dirname(os.path.abspath(__file__))
TEST_SIZE = 10 * 1024
TEST_COOKIE = {
'version': 0,
'name': 'test',
'value': 'ytdlp',
'port': None,
'port_specified': False,
'domain': '.example.com',
'domain_specified': True,
'domain_initial_dot': False,
'path': '/',
'path_specified': True,
'secure': False,
'expires': None,
'discard': False,
'comment': None,
'comment_url': None,
'rest': {},
}
TEST_COOKIE_VALUE = join_nonempty('name', 'value', delim='=', from_dict=TEST_COOKIE)
TEST_INFO = {'url': 'http://www.example.com/'}
def cookiejar_Cookie(**cookie_args):
return compat_http_cookiejar_Cookie(**compat_kwargs(cookie_args))
def ifExternalFDAvailable(externalFD):
return unittest.skipUnless(externalFD.available(),
externalFD.get_basename() + ' not found')
class HTTPTestRequestHandler(compat_http_server.BaseHTTPRequestHandler):
def log_message(self, format, *args):
@ -70,7 +115,7 @@ class HTTPTestRequestHandler(compat_http_server.BaseHTTPRequestHandler):
assert False, 'unrecognised server path'
@unittest.skipUnless(Aria2pFD.available(), 'aria2p module not found')
@ifExternalFDAvailable(Aria2pFD)
class TestAria2pFD(unittest.TestCase):
def setUp(self):
self.httpd = compat_http_server.HTTPServer(
@ -111,5 +156,103 @@ class TestAria2pFD(unittest.TestCase):
})
@ifExternalFDAvailable(HttpieFD)
class TestHttpieFD(unittest.TestCase):
def test_make_cmd(self):
with FakeYDL() as ydl:
downloader = HttpieFD(ydl, {})
self.assertEqual(
downloader._make_cmd('test', TEST_INFO),
['http', '--download', '--output', 'test', 'http://www.example.com/'])
# Test cookie header is added
ydl.cookiejar.set_cookie(cookiejar_Cookie(**TEST_COOKIE))
self.assertEqual(
downloader._make_cmd('test', TEST_INFO),
['http', '--download', '--output', 'test',
'http://www.example.com/', 'Cookie:' + TEST_COOKIE_VALUE])
@ifExternalFDAvailable(AxelFD)
class TestAxelFD(unittest.TestCase):
def test_make_cmd(self):
with FakeYDL() as ydl:
downloader = AxelFD(ydl, {})
self.assertEqual(
downloader._make_cmd('test', TEST_INFO),
['axel', '-o', 'test', '--', 'http://www.example.com/'])
# Test cookie header is added
ydl.cookiejar.set_cookie(cookiejar_Cookie(**TEST_COOKIE))
self.assertEqual(
downloader._make_cmd('test', TEST_INFO),
['axel', '-o', 'test', '-H', 'Cookie: ' + TEST_COOKIE_VALUE,
'--max-redirect=0', '--', 'http://www.example.com/'])
@ifExternalFDAvailable(WgetFD)
class TestWgetFD(unittest.TestCase):
def test_make_cmd(self):
with FakeYDL() as ydl:
downloader = WgetFD(ydl, {})
self.assertNotIn('--load-cookies', downloader._make_cmd('test', TEST_INFO))
# Test cookiejar tempfile arg is added
ydl.cookiejar.set_cookie(cookiejar_Cookie(**TEST_COOKIE))
self.assertIn('--load-cookies', downloader._make_cmd('test', TEST_INFO))
@ifExternalFDAvailable(CurlFD)
class TestCurlFD(unittest.TestCase):
def test_make_cmd(self):
with FakeYDL() as ydl:
downloader = CurlFD(ydl, {})
self.assertNotIn('--cookie', downloader._make_cmd('test', TEST_INFO))
# Test cookie header is added
ydl.cookiejar.set_cookie(cookiejar_Cookie(**TEST_COOKIE))
self.assertIn('--cookie', downloader._make_cmd('test', TEST_INFO))
self.assertIn(TEST_COOKIE_VALUE, downloader._make_cmd('test', TEST_INFO))
@ifExternalFDAvailable(Aria2cFD)
class TestAria2cFD(unittest.TestCase):
def test_make_cmd(self):
with FakeYDL() as ydl:
downloader = Aria2cFD(ydl, {})
downloader._make_cmd('test', TEST_INFO)
self.assertFalse(hasattr(downloader, '_cookies_tempfile'))
# Test cookiejar tempfile arg is added
ydl.cookiejar.set_cookie(cookiejar_Cookie(**TEST_COOKIE))
cmd = downloader._make_cmd('test', TEST_INFO)
self.assertIn('--load-cookies=%s' % downloader._cookies_tempfile, cmd)
@ifExternalFDAvailable(FFmpegFD)
class TestFFmpegFD(unittest.TestCase):
_args = []
def _test_cmd(self, args):
self._args = args
def test_make_cmd(self):
with FakeYDL() as ydl:
downloader = FFmpegFD(ydl, {})
downloader._debug_cmd = self._test_cmd
info_dict = TEST_INFO.copy()
info_dict['ext'] = 'mp4'
downloader._call_downloader('test', info_dict)
self.assertEqual(self._args, [
'ffmpeg', '-y', '-i', 'http://www.example.com/',
'-c', 'copy', '-f', 'mp4', 'file:test'])
# Test cookies arg is added
ydl.cookiejar.set_cookie(cookiejar_Cookie(**TEST_COOKIE))
downloader._call_downloader('test', info_dict)
self.assertEqual(self._args, [
'ffmpeg', '-y', '-cookies', TEST_COOKIE_VALUE + '; path=/; domain=.example.com;\r\n',
'-i', 'http://www.example.com/', '-c', 'copy', '-f', 'mp4', 'file:test'])
if __name__ == '__main__':
unittest.main()

View File

@ -24,21 +24,24 @@ except AttributeError:
class TestExecution(unittest.TestCase):
def setUp(self):
self.module = 'youtube_dl'
if sys.version_info < (2, 7):
self.module += '.__main__'
def test_import(self):
subprocess.check_call([sys.executable, '-c', 'import youtube_dl'], cwd=rootDir)
@unittest.skipIf(sys.version_info < (2, 7), 'Python 2.6 doesn\'t support package execution')
def test_module_exec(self):
subprocess.check_call([sys.executable, '-m', 'youtube_dl', '--version'], cwd=rootDir, stdout=_DEV_NULL)
subprocess.check_call([sys.executable, '-m', self.module, '--version'], cwd=rootDir, stdout=_DEV_NULL)
def test_main_exec(self):
subprocess.check_call([sys.executable, os.path.normpath('youtube_dl/__main__.py'), '--version'], cwd=rootDir, stdout=_DEV_NULL)
@unittest.skipIf(sys.version_info < (2, 7), 'Python 2.6 doesn\'t support package execution')
def test_cmdline_umlauts(self):
os.environ['PYTHONIOENCODING'] = 'utf-8'
p = subprocess.Popen(
[sys.executable, os.path.normpath('youtube_dl/__main__.py'), encodeArgument('ä'), '--version'],
[sys.executable, '-m', self.module, encodeArgument('ä'), '--version'],
cwd=rootDir, stdout=_DEV_NULL, stderr=subprocess.PIPE)
_, stderr = p.communicate()
self.assertFalse(stderr)

View File

@ -8,33 +8,161 @@ import sys
import unittest
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
import contextlib
import gzip
import io
import ssl
import tempfile
import threading
import zlib
# avoid deprecated alias assertRaisesRegexp
if hasattr(unittest.TestCase, 'assertRaisesRegex'):
unittest.TestCase.assertRaisesRegexp = unittest.TestCase.assertRaisesRegex
try:
import brotli
except ImportError:
brotli = None
try:
from urllib.request import pathname2url
except ImportError:
from urllib import pathname2url
from youtube_dl.compat import (
compat_http_cookiejar_Cookie,
compat_http_server,
compat_str as str,
compat_urllib_error,
compat_urllib_HTTPError,
compat_urllib_parse,
compat_urllib_request,
)
from youtube_dl.utils import (
sanitized_Request,
urlencode_postdata,
)
from test.helper import (
FakeYDL,
FakeLogger,
http_server_port,
)
from youtube_dl import YoutubeDL
from youtube_dl.compat import compat_http_server, compat_urllib_request
import ssl
import threading
TEST_DIR = os.path.dirname(os.path.abspath(__file__))
class HTTPTestRequestHandler(compat_http_server.BaseHTTPRequestHandler):
protocol_version = 'HTTP/1.1'
# work-around old/new -style class inheritance
def super(self, meth_name, *args, **kwargs):
from types import MethodType
try:
super()
fn = lambda s, m, *a, **k: getattr(super(), m)(*a, **k)
except TypeError:
fn = lambda s, m, *a, **k: getattr(compat_http_server.BaseHTTPRequestHandler, m)(s, *a, **k)
self.super = MethodType(fn, self)
return self.super(meth_name, *args, **kwargs)
def log_message(self, format, *args):
pass
def do_GET(self):
if self.path == '/video.html':
def _headers(self):
payload = str(self.headers).encode('utf-8')
self.send_response(200)
self.send_header('Content-Type', 'application/json')
self.send_header('Content-Length', str(len(payload)))
self.end_headers()
self.wfile.write(payload)
def _redirect(self):
self.send_response(int(self.path[len('/redirect_'):]))
self.send_header('Location', '/method')
self.send_header('Content-Length', '0')
self.end_headers()
def _method(self, method, payload=None):
self.send_response(200)
self.send_header('Content-Length', str(len(payload or '')))
self.send_header('Method', method)
self.end_headers()
if payload:
self.wfile.write(payload)
def _status(self, status):
payload = '<html>{0} NOT FOUND</html>'.format(status).encode('utf-8')
self.send_response(int(status))
self.send_header('Content-Type', 'text/html; charset=utf-8')
self.send_header('Content-Length', str(len(payload)))
self.end_headers()
self.wfile.write(b'<html><video src="/vid.mp4" /></html>')
self.wfile.write(payload)
def _read_data(self):
if 'Content-Length' in self.headers:
return self.rfile.read(int(self.headers['Content-Length']))
def _test_url(self, path, host='127.0.0.1', scheme='http', port=None):
return '{0}://{1}:{2}/{3}'.format(
scheme, host,
port if port is not None
else http_server_port(self.server), path)
def do_POST(self):
data = self._read_data()
if self.path.startswith('/redirect_'):
self._redirect()
elif self.path.startswith('/method'):
self._method('POST', data)
elif self.path.startswith('/headers'):
self._headers()
else:
self._status(404)
def do_HEAD(self):
if self.path.startswith('/redirect_'):
self._redirect()
elif self.path.startswith('/method'):
self._method('HEAD')
else:
self._status(404)
def do_PUT(self):
data = self._read_data()
if self.path.startswith('/redirect_'):
self._redirect()
elif self.path.startswith('/method'):
self._method('PUT', data)
else:
self._status(404)
def do_GET(self):
def respond(payload=b'<html><video src="/vid.mp4" /></html>',
payload_type='text/html; charset=utf-8',
payload_encoding=None,
resp_code=200):
self.send_response(resp_code)
self.send_header('Content-Type', payload_type)
if payload_encoding:
self.send_header('Content-Encoding', payload_encoding)
self.send_header('Content-Length', str(len(payload))) # required for persistent connections
self.end_headers()
self.wfile.write(payload)
def gzip_compress(p):
buf = io.BytesIO()
with contextlib.closing(gzip.GzipFile(fileobj=buf, mode='wb')) as f:
f.write(p)
return buf.getvalue()
if self.path == '/video.html':
respond()
elif self.path == '/vid.mp4':
self.send_response(200)
self.send_header('Content-Type', 'video/mp4')
self.end_headers()
self.wfile.write(b'\x00\x00\x00\x00\x20\x66\x74[video]')
respond(b'\x00\x00\x00\x00\x20\x66\x74[video]', 'video/mp4')
elif self.path == '/302':
if sys.version_info[0] == 3:
# XXX: Python 3 http server does not allow non-ASCII header values
@ -42,60 +170,316 @@ class HTTPTestRequestHandler(compat_http_server.BaseHTTPRequestHandler):
self.end_headers()
return
new_url = 'http://127.0.0.1:%d/中文.html' % http_server_port(self.server)
new_url = self._test_url('中文.html')
self.send_response(302)
self.send_header(b'Location', new_url.encode('utf-8'))
self.end_headers()
elif self.path == '/%E4%B8%AD%E6%96%87.html':
self.send_response(200)
self.send_header('Content-Type', 'text/html; charset=utf-8')
respond()
elif self.path == '/%c7%9f':
respond()
elif self.path.startswith('/redirect_'):
self._redirect()
elif self.path.startswith('/method'):
self._method('GET')
elif self.path.startswith('/headers'):
self._headers()
elif self.path.startswith('/308-to-headers'):
self.send_response(308)
self.send_header('Location', '/headers')
self.send_header('Content-Length', '0')
self.end_headers()
self.wfile.write(b'<html><video src="/vid.mp4" /></html>')
elif self.path == '/trailing_garbage':
payload = b'<html><video src="/vid.mp4" /></html>'
compressed = gzip_compress(payload) + b'trailing garbage'
respond(compressed, payload_encoding='gzip')
elif self.path == '/302-non-ascii-redirect':
new_url = self._test_url('中文.html')
# actually respond with permanent redirect
self.send_response(301)
self.send_header('Location', new_url)
self.send_header('Content-Length', '0')
self.end_headers()
elif self.path == '/content-encoding':
encodings = self.headers.get('ytdl-encoding', '')
payload = b'<html><video src="/vid.mp4" /></html>'
for encoding in filter(None, (e.strip() for e in encodings.split(','))):
if encoding == 'br' and brotli:
payload = brotli.compress(payload)
elif encoding == 'gzip':
payload = gzip_compress(payload)
elif encoding == 'deflate':
payload = zlib.compress(payload)
elif encoding == 'unsupported':
payload = b'raw'
break
else:
assert False
self._status(415)
return
respond(payload, payload_encoding=encodings)
else:
self._status(404)
def send_header(self, keyword, value):
"""
Forcibly allow HTTP server to send non percent-encoded non-ASCII characters in headers.
This is against what is defined in RFC 3986: but we need to test that we support this
since some sites incorrectly do this.
"""
if keyword.lower() == 'connection':
return self.super('send_header', keyword, value)
if not hasattr(self, '_headers_buffer'):
self._headers_buffer = []
self._headers_buffer.append('{0}: {1}\r\n'.format(keyword, value).encode('utf-8'))
def end_headers(self):
if hasattr(self, '_headers_buffer'):
self.wfile.write(b''.join(self._headers_buffer))
self._headers_buffer = []
self.super('end_headers')
class TestHTTP(unittest.TestCase):
def setUp(self):
self.httpd = compat_http_server.HTTPServer(
# HTTP server
self.http_httpd = compat_http_server.HTTPServer(
('127.0.0.1', 0), HTTPTestRequestHandler)
self.port = http_server_port(self.httpd)
self.server_thread = threading.Thread(target=self.httpd.serve_forever)
self.server_thread.daemon = True
self.server_thread.start()
self.http_port = http_server_port(self.http_httpd)
self.http_server_thread = threading.Thread(target=self.http_httpd.serve_forever)
self.http_server_thread.daemon = True
self.http_server_thread.start()
try:
from http.server import ThreadingHTTPServer
except ImportError:
try:
from socketserver import ThreadingMixIn
except ImportError:
from SocketServer import ThreadingMixIn
class ThreadingHTTPServer(ThreadingMixIn, compat_http_server.HTTPServer):
pass
# HTTPS server
certfn = os.path.join(TEST_DIR, 'testcert.pem')
self.https_httpd = ThreadingHTTPServer(
('127.0.0.1', 0), HTTPTestRequestHandler)
try:
sslctx = ssl.SSLContext(ssl.PROTOCOL_SSLv23)
sslctx.verify_mode = ssl.CERT_NONE
sslctx.check_hostname = False
sslctx.load_cert_chain(certfn, None)
self.https_httpd.socket = sslctx.wrap_socket(
self.https_httpd.socket, server_side=True)
except AttributeError:
self.https_httpd.socket = ssl.wrap_socket(
self.https_httpd.socket, certfile=certfn, server_side=True)
self.https_port = http_server_port(self.https_httpd)
self.https_server_thread = threading.Thread(target=self.https_httpd.serve_forever)
self.https_server_thread.daemon = True
self.https_server_thread.start()
def tearDown(self):
def closer(svr):
def _closer():
svr.shutdown()
svr.server_close()
return _closer
shutdown_thread = threading.Thread(target=closer(self.http_httpd))
shutdown_thread.start()
self.http_server_thread.join(2.0)
shutdown_thread = threading.Thread(target=closer(self.https_httpd))
shutdown_thread.start()
self.https_server_thread.join(2.0)
def _test_url(self, path, host='127.0.0.1', scheme='http', port=None):
return '{0}://{1}:{2}/{3}'.format(
scheme, host,
port if port is not None
else self.https_port if scheme == 'https'
else self.http_port, path)
@unittest.skipUnless(
sys.version_info >= (3, 2)
or (sys.version_info[0] == 2 and sys.version_info[1:] >= (7, 9)),
'No support for certificate check in SSL')
def test_nocheckcertificate(self):
with FakeYDL({'logger': FakeLogger()}) as ydl:
with self.assertRaises(compat_urllib_error.URLError):
ydl.urlopen(sanitized_Request(self._test_url('headers', scheme='https')))
with FakeYDL({'logger': FakeLogger(), 'nocheckcertificate': True}) as ydl:
r = ydl.urlopen(sanitized_Request(self._test_url('headers', scheme='https')))
self.assertEqual(r.getcode(), 200)
r.close()
def test_percent_encode(self):
with FakeYDL() as ydl:
# Unicode characters should be encoded with uppercase percent-encoding
res = ydl.urlopen(sanitized_Request(self._test_url('中文.html')))
self.assertEqual(res.getcode(), 200)
res.close()
# don't normalize existing percent encodings
res = ydl.urlopen(sanitized_Request(self._test_url('%c7%9f')))
self.assertEqual(res.getcode(), 200)
res.close()
def test_unicode_path_redirection(self):
# XXX: Python 3 http server does not allow non-ASCII header values
if sys.version_info[0] == 3:
return
with FakeYDL() as ydl:
r = ydl.urlopen(sanitized_Request(self._test_url('302-non-ascii-redirect')))
self.assertEqual(r.url, self._test_url('%E4%B8%AD%E6%96%87.html'))
r.close()
ydl = YoutubeDL({'logger': FakeLogger()})
r = ydl.extract_info('http://127.0.0.1:%d/302' % self.port)
self.assertEqual(r['entries'][0]['url'], 'http://127.0.0.1:%d/vid.mp4' % self.port)
def test_redirect(self):
with FakeYDL() as ydl:
def do_req(redirect_status, method, check_no_content=False):
data = b'testdata' if method in ('POST', 'PUT') else None
res = ydl.urlopen(sanitized_Request(
self._test_url('redirect_{0}'.format(redirect_status)),
method=method, data=data))
if check_no_content:
self.assertNotIn('Content-Type', res.headers)
return res.read().decode('utf-8'), res.headers.get('method', '')
# A 303 must either use GET or HEAD for subsequent request
self.assertEqual(do_req(303, 'POST'), ('', 'GET'))
self.assertEqual(do_req(303, 'HEAD'), ('', 'HEAD'))
self.assertEqual(do_req(303, 'PUT'), ('', 'GET'))
class TestHTTPS(unittest.TestCase):
def setUp(self):
certfn = os.path.join(TEST_DIR, 'testcert.pem')
self.httpd = compat_http_server.HTTPServer(
('127.0.0.1', 0), HTTPTestRequestHandler)
self.httpd.socket = ssl.wrap_socket(
self.httpd.socket, certfile=certfn, server_side=True)
self.port = http_server_port(self.httpd)
self.server_thread = threading.Thread(target=self.httpd.serve_forever)
self.server_thread.daemon = True
self.server_thread.start()
# 301 and 302 turn POST only into a GET, with no Content-Type
self.assertEqual(do_req(301, 'POST', True), ('', 'GET'))
self.assertEqual(do_req(301, 'HEAD'), ('', 'HEAD'))
self.assertEqual(do_req(302, 'POST', True), ('', 'GET'))
self.assertEqual(do_req(302, 'HEAD'), ('', 'HEAD'))
def test_nocheckcertificate(self):
if sys.version_info >= (2, 7, 9): # No certificate checking anyways
ydl = YoutubeDL({'logger': FakeLogger()})
self.assertRaises(
Exception,
ydl.extract_info, 'https://127.0.0.1:%d/video.html' % self.port)
self.assertEqual(do_req(301, 'PUT'), ('testdata', 'PUT'))
self.assertEqual(do_req(302, 'PUT'), ('testdata', 'PUT'))
ydl = YoutubeDL({'logger': FakeLogger(), 'nocheckcertificate': True})
r = ydl.extract_info('https://127.0.0.1:%d/video.html' % self.port)
self.assertEqual(r['entries'][0]['url'], 'https://127.0.0.1:%d/vid.mp4' % self.port)
# 307 and 308 should not change method
for m in ('POST', 'PUT'):
self.assertEqual(do_req(307, m), ('testdata', m))
self.assertEqual(do_req(308, m), ('testdata', m))
self.assertEqual(do_req(307, 'HEAD'), ('', 'HEAD'))
self.assertEqual(do_req(308, 'HEAD'), ('', 'HEAD'))
# These should not redirect and instead raise an HTTPError
for code in (300, 304, 305, 306):
with self.assertRaises(compat_urllib_HTTPError):
do_req(code, 'GET')
def test_content_type(self):
# https://github.com/yt-dlp/yt-dlp/commit/379a4f161d4ad3e40932dcf5aca6e6fb9715ab28
with FakeYDL({'nocheckcertificate': True}) as ydl:
# method should be auto-detected as POST
r = sanitized_Request(self._test_url('headers', scheme='https'), data=urlencode_postdata({'test': 'test'}))
headers = ydl.urlopen(r).read().decode('utf-8')
self.assertIn('Content-Type: application/x-www-form-urlencoded', headers)
# test http
r = sanitized_Request(self._test_url('headers'), data=urlencode_postdata({'test': 'test'}))
headers = ydl.urlopen(r).read().decode('utf-8')
self.assertIn('Content-Type: application/x-www-form-urlencoded', headers)
def test_cookiejar(self):
with FakeYDL() as ydl:
ydl.cookiejar.set_cookie(compat_http_cookiejar_Cookie(
0, 'test', 'ytdl', None, False, '127.0.0.1', True,
False, '/headers', True, False, None, False, None, None, {}))
data = ydl.urlopen(sanitized_Request(
self._test_url('headers'))).read().decode('utf-8')
self.assertIn('Cookie: test=ytdl', data)
def test_passed_cookie_header(self):
# We should accept a Cookie header being passed as in normal headers and handle it appropriately.
with FakeYDL() as ydl:
# Specified Cookie header should be used
res = ydl.urlopen(sanitized_Request(
self._test_url('headers'), headers={'Cookie': 'test=test'})).read().decode('utf-8')
self.assertIn('Cookie: test=test', res)
# Specified Cookie header should be removed on any redirect
res = ydl.urlopen(sanitized_Request(
self._test_url('308-to-headers'), headers={'Cookie': 'test=test'})).read().decode('utf-8')
self.assertNotIn('Cookie: test=test', res)
# Specified Cookie header should override global cookiejar for that request
ydl.cookiejar.set_cookie(compat_http_cookiejar_Cookie(
0, 'test', 'ytdlp', None, False, '127.0.0.1', True,
False, '/headers', True, False, None, False, None, None, {}))
data = ydl.urlopen(sanitized_Request(
self._test_url('headers'), headers={'Cookie': 'test=test'})).read().decode('utf-8')
self.assertNotIn('Cookie: test=ytdlp', data)
self.assertIn('Cookie: test=test', data)
def test_no_compression_compat_header(self):
with FakeYDL() as ydl:
data = ydl.urlopen(
sanitized_Request(
self._test_url('headers'),
headers={'Youtubedl-no-compression': True})).read()
self.assertIn(b'Accept-Encoding: identity', data)
self.assertNotIn(b'youtubedl-no-compression', data.lower())
def test_gzip_trailing_garbage(self):
# https://github.com/ytdl-org/youtube-dl/commit/aa3e950764337ef9800c936f4de89b31c00dfcf5
# https://github.com/ytdl-org/youtube-dl/commit/6f2ec15cee79d35dba065677cad9da7491ec6e6f
with FakeYDL() as ydl:
data = ydl.urlopen(sanitized_Request(self._test_url('trailing_garbage'))).read().decode('utf-8')
self.assertEqual(data, '<html><video src="/vid.mp4" /></html>')
def __test_compression(self, encoding):
with FakeYDL() as ydl:
res = ydl.urlopen(
sanitized_Request(
self._test_url('content-encoding'),
headers={'ytdl-encoding': encoding}))
self.assertEqual(res.headers.get('Content-Encoding'), encoding)
self.assertEqual(res.read(), b'<html><video src="/vid.mp4" /></html>')
@unittest.skipUnless(brotli, 'brotli support is not installed')
@unittest.expectedFailure
def test_brotli(self):
self.__test_compression('br')
@unittest.expectedFailure
def test_deflate(self):
self.__test_compression('deflate')
@unittest.expectedFailure
def test_gzip(self):
self.__test_compression('gzip')
@unittest.expectedFailure # not yet implemented
def test_multiple_encodings(self):
# https://www.rfc-editor.org/rfc/rfc9110.html#section-8.4
with FakeYDL() as ydl:
for pair in ('gzip,deflate', 'deflate, gzip', 'gzip, gzip', 'deflate, deflate'):
res = ydl.urlopen(
sanitized_Request(
self._test_url('content-encoding'),
headers={'ytdl-encoding': pair}))
self.assertEqual(res.headers.get('Content-Encoding'), pair)
self.assertEqual(res.read(), b'<html><video src="/vid.mp4" /></html>')
def test_unsupported_encoding(self):
# it should return the raw content
with FakeYDL() as ydl:
res = ydl.urlopen(
sanitized_Request(
self._test_url('content-encoding'),
headers={'ytdl-encoding': 'unsupported'}))
self.assertEqual(res.headers.get('Content-Encoding'), 'unsupported')
self.assertEqual(res.read(), b'raw')
def _build_proxy_handler(name):
@ -109,7 +493,7 @@ def _build_proxy_handler(name):
self.send_response(200)
self.send_header('Content-Type', 'text/plain; charset=utf-8')
self.end_headers()
self.wfile.write('{self.proxy_name}: {self.path}'.format(self=self).encode('utf-8'))
self.wfile.write('{0}: {1}'.format(self.proxy_name, self.path).encode('utf-8'))
return HTTPTestRequestHandler
@ -129,10 +513,30 @@ class TestProxy(unittest.TestCase):
self.geo_proxy_thread.daemon = True
self.geo_proxy_thread.start()
def tearDown(self):
def closer(svr):
def _closer():
svr.shutdown()
svr.server_close()
return _closer
shutdown_thread = threading.Thread(target=closer(self.proxy))
shutdown_thread.start()
self.proxy_thread.join(2.0)
shutdown_thread = threading.Thread(target=closer(self.geo_proxy))
shutdown_thread.start()
self.geo_proxy_thread.join(2.0)
def _test_proxy(self, host='127.0.0.1', port=None):
return '{0}:{1}'.format(
host, port if port is not None else self.port)
def test_proxy(self):
geo_proxy = '127.0.0.1:{0}'.format(self.geo_port)
geo_proxy = self._test_proxy(port=self.geo_port)
ydl = YoutubeDL({
'proxy': '127.0.0.1:{0}'.format(self.port),
'proxy': self._test_proxy(),
'geo_verification_proxy': geo_proxy,
})
url = 'http://foo.com/bar'
@ -146,7 +550,7 @@ class TestProxy(unittest.TestCase):
def test_proxy_with_idn(self):
ydl = YoutubeDL({
'proxy': '127.0.0.1:{0}'.format(self.port),
'proxy': self._test_proxy(),
})
url = 'http://中文.tw/'
response = ydl.urlopen(url).read().decode('utf-8')
@ -154,5 +558,25 @@ class TestProxy(unittest.TestCase):
self.assertEqual(response, 'normal: http://xn--fiq228c.tw/')
class TestFileURL(unittest.TestCase):
# See https://github.com/ytdl-org/youtube-dl/issues/8227
def test_file_urls(self):
tf = tempfile.NamedTemporaryFile(delete=False)
tf.write(b'foobar')
tf.close()
url = compat_urllib_parse.urljoin('file://', pathname2url(tf.name))
with FakeYDL() as ydl:
self.assertRaisesRegexp(
compat_urllib_error.URLError, 'file:// scheme is explicitly disabled in youtube-dl for security reasons', ydl.urlopen, url)
# not yet implemented
"""
with FakeYDL({'enable_file_urls': True}) as ydl:
res = ydl.urlopen(url)
self.assertEqual(res.read(), b'foobar')
res.close()
"""
os.unlink(tf.name)
if __name__ == '__main__':
unittest.main()

View File

@ -79,10 +79,12 @@ from youtube_dl.utils import (
rot47,
shell_quote,
smuggle_url,
str_or_none,
str_to_int,
strip_jsonp,
strip_or_none,
subtitles_filename,
T,
timeconvert,
traverse_obj,
try_call,
@ -1566,6 +1568,7 @@ Line 1
self.assertEqual(variadic('spam', allowed_types=[dict]), 'spam')
def test_traverse_obj(self):
str = compat_str
_TEST_DATA = {
100: 100,
1.2: 1.2,
@ -1598,8 +1601,8 @@ Line 1
# Test Ellipsis behavior
self.assertCountEqual(traverse_obj(_TEST_DATA, Ellipsis),
(item for item in _TEST_DATA.values() if item is not None),
msg='`...` should give all values except `None`')
(item for item in _TEST_DATA.values() if item not in (None, {})),
msg='`...` should give all non discarded values')
self.assertCountEqual(traverse_obj(_TEST_DATA, ('urls', 0, Ellipsis)), _TEST_DATA['urls'][0].values(),
msg='`...` selection for dicts should select all values')
self.assertEqual(traverse_obj(_TEST_DATA, (Ellipsis, Ellipsis, 'url')),
@ -1607,13 +1610,51 @@ Line 1
msg='nested `...` queries should work')
self.assertCountEqual(traverse_obj(_TEST_DATA, (Ellipsis, Ellipsis, 'index')), range(4),
msg='`...` query result should be flattened')
self.assertEqual(traverse_obj(iter(range(4)), Ellipsis), list(range(4)),
msg='`...` should accept iterables')
# Test function as key
self.assertEqual(traverse_obj(_TEST_DATA, lambda x, y: x == 'urls' and isinstance(y, list)),
[_TEST_DATA['urls']],
msg='function as query key should perform a filter based on (key, value)')
self.assertCountEqual(traverse_obj(_TEST_DATA, lambda _, x: isinstance(x[0], compat_str)), ('str',),
msg='exceptions in the query function should be caught')
self.assertCountEqual(traverse_obj(_TEST_DATA, lambda _, x: isinstance(x[0], str)), set(('str',)),
msg='exceptions in the query function should be catched')
self.assertEqual(traverse_obj(iter(range(4)), lambda _, x: x % 2 == 0), [0, 2],
msg='function key should accept iterables')
if __debug__:
with self.assertRaises(Exception, msg='Wrong function signature should raise in debug'):
traverse_obj(_TEST_DATA, lambda a: Ellipsis)
with self.assertRaises(Exception, msg='Wrong function signature should raise in debug'):
traverse_obj(_TEST_DATA, lambda a, b, c: Ellipsis)
# Test set as key (transformation/type, like `expected_type`)
self.assertEqual(traverse_obj(_TEST_DATA, (Ellipsis, T(str.upper), )), ['STR'],
msg='Function in set should be a transformation')
self.assertEqual(traverse_obj(_TEST_DATA, (Ellipsis, T(str))), ['str'],
msg='Type in set should be a type filter')
self.assertEqual(traverse_obj(_TEST_DATA, T(dict)), _TEST_DATA,
msg='A single set should be wrapped into a path')
self.assertEqual(traverse_obj(_TEST_DATA, (Ellipsis, T(str.upper))), ['STR'],
msg='Transformation function should not raise')
self.assertEqual(traverse_obj(_TEST_DATA, (Ellipsis, T(str_or_none))),
[item for item in map(str_or_none, _TEST_DATA.values()) if item is not None],
msg='Function in set should be a transformation')
if __debug__:
with self.assertRaises(Exception, msg='Sets with length != 1 should raise in debug'):
traverse_obj(_TEST_DATA, set())
with self.assertRaises(Exception, msg='Sets with length != 1 should raise in debug'):
traverse_obj(_TEST_DATA, set((str.upper, str)))
# Test `slice` as a key
_SLICE_DATA = [0, 1, 2, 3, 4]
self.assertEqual(traverse_obj(_TEST_DATA, ('dict', slice(1))), None,
msg='slice on a dictionary should not throw')
self.assertEqual(traverse_obj(_SLICE_DATA, slice(1)), _SLICE_DATA[:1],
msg='slice key should apply slice to sequence')
self.assertEqual(traverse_obj(_SLICE_DATA, slice(1, 2)), _SLICE_DATA[1:2],
msg='slice key should apply slice to sequence')
self.assertEqual(traverse_obj(_SLICE_DATA, slice(1, 4, 2)), _SLICE_DATA[1:4:2],
msg='slice key should apply slice to sequence')
# Test alternative paths
self.assertEqual(traverse_obj(_TEST_DATA, 'fail', 'str'), 'str',
@ -1659,15 +1700,23 @@ Line 1
{0: ['https://www.example.com/1', 'https://www.example.com/0']},
msg='triple nesting in dict path should be treated as branches')
self.assertEqual(traverse_obj(_TEST_DATA, {0: 'fail'}), {},
msg='remove `None` values when dict key')
msg='remove `None` values when top level dict key fails')
self.assertEqual(traverse_obj(_TEST_DATA, {0: 'fail'}, default=Ellipsis), {0: Ellipsis},
msg='do not remove `None` values if `default`')
self.assertEqual(traverse_obj(_TEST_DATA, {0: 'dict'}), {0: {}},
msg='do not remove empty values when dict key')
self.assertEqual(traverse_obj(_TEST_DATA, {0: 'dict'}, default=Ellipsis), {0: {}},
msg='do not remove empty values when dict key and a default')
self.assertEqual(traverse_obj(_TEST_DATA, {0: ('dict', Ellipsis)}), {0: []},
msg='if branch in dict key not successful, return `[]`')
msg='use `default` if key fails and `default`')
self.assertEqual(traverse_obj(_TEST_DATA, {0: 'dict'}), {},
msg='remove empty values when dict key')
self.assertEqual(traverse_obj(_TEST_DATA, {0: 'dict'}, default=Ellipsis), {0: Ellipsis},
msg='use `default` when dict key and `default`')
self.assertEqual(traverse_obj(_TEST_DATA, {0: {0: 'fail'}}), {},
msg='remove empty values when nested dict key fails')
self.assertEqual(traverse_obj(None, {0: 'fail'}), {},
msg='default to dict if pruned')
self.assertEqual(traverse_obj(None, {0: 'fail'}, default=Ellipsis), {0: Ellipsis},
msg='default to dict if pruned and default is given')
self.assertEqual(traverse_obj(_TEST_DATA, {0: {0: 'fail'}}, default=Ellipsis), {0: {0: Ellipsis}},
msg='use nested `default` when nested dict key fails and `default`')
self.assertEqual(traverse_obj(_TEST_DATA, {0: ('dict', Ellipsis)}), {},
msg='remove key if branch in dict key not successful')
# Testing default parameter behavior
_DEFAULT_DATA = {'None': None, 'int': 0, 'list': []}
@ -1691,20 +1740,55 @@ Line 1
msg='if branched but not successful return `[]`, not `default`')
self.assertEqual(traverse_obj(_DEFAULT_DATA, ('list', Ellipsis)), [],
msg='if branched but object is empty return `[]`, not `default`')
self.assertEqual(traverse_obj(None, Ellipsis), [],
msg='if branched but object is `None` return `[]`, not `default`')
self.assertEqual(traverse_obj({0: None}, (0, Ellipsis)), [],
msg='if branched but state is `None` return `[]`, not `default`')
branching_paths = [
('fail', Ellipsis),
(Ellipsis, 'fail'),
100 * ('fail',) + (Ellipsis,),
(Ellipsis,) + 100 * ('fail',),
]
for branching_path in branching_paths:
self.assertEqual(traverse_obj({}, branching_path), [],
msg='if branched but state is `None`, return `[]` (not `default`)')
self.assertEqual(traverse_obj({}, 'fail', branching_path), [],
msg='if branching in last alternative and previous did not match, return `[]` (not `default`)')
self.assertEqual(traverse_obj({0: 'x'}, 0, branching_path), 'x',
msg='if branching in last alternative and previous did match, return single value')
self.assertEqual(traverse_obj({0: 'x'}, branching_path, 0), 'x',
msg='if branching in first alternative and non-branching path does match, return single value')
self.assertEqual(traverse_obj({}, branching_path, 'fail'), None,
msg='if branching in first alternative and non-branching path does not match, return `default`')
# Testing expected_type behavior
_EXPECTED_TYPE_DATA = {'str': 'str', 'int': 0}
self.assertEqual(traverse_obj(_EXPECTED_TYPE_DATA, 'str', expected_type=compat_str), 'str',
msg='accept matching `expected_type` type')
self.assertEqual(traverse_obj(_EXPECTED_TYPE_DATA, 'str', expected_type=int), None,
msg='reject non matching `expected_type` type')
self.assertEqual(traverse_obj(_EXPECTED_TYPE_DATA, 'int', expected_type=lambda x: compat_str(x)), '0',
msg='transform type using type function')
self.assertEqual(traverse_obj(_EXPECTED_TYPE_DATA, 'str',
expected_type=lambda _: 1 / 0), None,
msg='wrap expected_type function in try_call')
self.assertEqual(traverse_obj(_EXPECTED_TYPE_DATA, Ellipsis, expected_type=compat_str), ['str'],
msg='eliminate items that expected_type fails on')
self.assertEqual(traverse_obj(_EXPECTED_TYPE_DATA, 'str', expected_type=str),
'str', msg='accept matching `expected_type` type')
self.assertEqual(traverse_obj(_EXPECTED_TYPE_DATA, 'str', expected_type=int),
None, msg='reject non matching `expected_type` type')
self.assertEqual(traverse_obj(_EXPECTED_TYPE_DATA, 'int', expected_type=lambda x: str(x)),
'0', msg='transform type using type function')
self.assertEqual(traverse_obj(_EXPECTED_TYPE_DATA, 'str', expected_type=lambda _: 1 / 0),
None, msg='wrap expected_type function in try_call')
self.assertEqual(traverse_obj(_EXPECTED_TYPE_DATA, Ellipsis, expected_type=str),
['str'], msg='eliminate items that expected_type fails on')
self.assertEqual(traverse_obj(_TEST_DATA, {0: 100, 1: 1.2}, expected_type=int),
{0: 100}, msg='type as expected_type should filter dict values')
self.assertEqual(traverse_obj(_TEST_DATA, {0: 100, 1: 1.2, 2: 'None'}, expected_type=str_or_none),
{0: '100', 1: '1.2'}, msg='function as expected_type should transform dict values')
self.assertEqual(traverse_obj(_TEST_DATA, ({0: 1.2}, 0, set((int_or_none,))), expected_type=int),
1, msg='expected_type should not filter non final dict values')
self.assertEqual(traverse_obj(_TEST_DATA, {0: {0: 100, 1: 'str'}}, expected_type=int),
{0: {0: 100}}, msg='expected_type should transform deep dict values')
self.assertEqual(traverse_obj(_TEST_DATA, [({0: '...'}, {0: '...'})], expected_type=type(Ellipsis)),
[{0: Ellipsis}, {0: Ellipsis}], msg='expected_type should transform branched dict values')
self.assertEqual(traverse_obj({1: {3: 4}}, [(1, 2), 3], expected_type=int),
[4], msg='expected_type regression for type matching in tuple branching')
self.assertEqual(traverse_obj(_TEST_DATA, ['data', Ellipsis], expected_type=int),
[], msg='expected_type regression for type matching in dict result')
# Test get_all behavior
_GET_ALL_DATA = {'key': [0, 1, 2]}
@ -1749,14 +1833,23 @@ Line 1
_traverse_string=True), '.',
msg='traverse into converted data if `traverse_string`')
self.assertEqual(traverse_obj(_TRAVERSE_STRING_DATA, ('str', Ellipsis),
_traverse_string=True), list('str'),
msg='`...` branching into string should result in list')
_traverse_string=True), 'str',
msg='`...` should result in string (same value) if `traverse_string`')
self.assertEqual(traverse_obj(_TRAVERSE_STRING_DATA, ('str', slice(0, None, 2)),
_traverse_string=True), 'sr',
msg='`slice` should result in string if `traverse_string`')
self.assertEqual(traverse_obj(_TRAVERSE_STRING_DATA, ('str', lambda i, v: i or v == "s"),
_traverse_string=True), 'str',
msg='function should result in string if `traverse_string`')
self.assertEqual(traverse_obj(_TRAVERSE_STRING_DATA, ('str', (0, 2)),
_traverse_string=True), ['s', 'r'],
msg='branching into string should result in list')
self.assertEqual(traverse_obj(_TRAVERSE_STRING_DATA, ('str', lambda _, x: x),
_traverse_string=True), list('str'),
msg='function branching into string should result in list')
msg='branching should result in list if `traverse_string`')
self.assertEqual(traverse_obj({}, (0, Ellipsis), _traverse_string=True), [],
msg='branching should result in list if `traverse_string`')
self.assertEqual(traverse_obj({}, (0, lambda x, y: True), _traverse_string=True), [],
msg='branching should result in list if `traverse_string`')
self.assertEqual(traverse_obj({}, (0, slice(1)), _traverse_string=True), [],
msg='branching should result in list if `traverse_string`')
# Test is_user_input behavior
_IS_USER_INPUT_DATA = {'range8': list(range(8))}
@ -1793,6 +1886,8 @@ Line 1
msg='failing str key on a `re.Match` should return `default`')
self.assertEqual(traverse_obj(mobj, 8), None,
msg='failing int key on a `re.Match` should return `default`')
self.assertEqual(traverse_obj(mobj, lambda k, _: k in (0, 'group')), ['0123', '3'],
msg='function on a `re.Match` should give group name as well')
def test_get_first(self):
self.assertEqual(get_first([{'a': None}, {'a': 'spam'}], 'a'), 'spam')

View File

@ -5,6 +5,7 @@ from __future__ import absolute_import, unicode_literals
import collections
import contextlib
import copy
import datetime
import errno
import fileinput
@ -34,10 +35,12 @@ from string import ascii_letters
from .compat import (
compat_basestring,
compat_cookiejar,
compat_collections_chain_map as ChainMap,
compat_filter as filter,
compat_get_terminal_size,
compat_http_client,
compat_http_cookiejar_Cookie,
compat_http_cookies_SimpleCookie,
compat_integer_types,
compat_kwargs,
compat_map as map,
@ -53,6 +56,7 @@ from .compat import (
from .utils import (
age_restricted,
args_to_str,
bug_reports_message,
ContentTooShortError,
date_from_str,
DateRange,
@ -97,6 +101,7 @@ from .utils import (
std_headers,
str_or_none,
subtitles_filename,
traverse_obj,
UnavailableVideoError,
url_basename,
version_tuple,
@ -376,6 +381,9 @@ class YoutubeDL(object):
self.params.update(params)
self.cache = Cache(self)
self._header_cookies = []
self._load_cookies_from_headers(self.params.get('http_headers'))
def check_deprecated(param, option, suggestion):
if self.params.get(param) is not None:
self.report_warning(
@ -582,7 +590,7 @@ class YoutubeDL(object):
if self.params.get('cookiefile') is not None:
self.cookiejar.save(ignore_discard=True, ignore_expires=True)
def trouble(self, message=None, tb=None):
def trouble(self, *args, **kwargs):
"""Determine action to take when a download problem appears.
Depending on if the downloader has been configured to ignore
@ -591,6 +599,11 @@ class YoutubeDL(object):
tb, if given, is additional traceback information.
"""
# message=None, tb=None, is_error=True
message = args[0] if len(args) > 0 else kwargs.get('message', None)
tb = args[1] if len(args) > 1 else kwargs.get('tb', None)
is_error = args[2] if len(args) > 2 else kwargs.get('is_error', True)
if message is not None:
self.to_stderr(message)
if self.params.get('verbose'):
@ -603,7 +616,10 @@ class YoutubeDL(object):
else:
tb_data = traceback.format_list(traceback.extract_stack())
tb = ''.join(tb_data)
if tb:
self.to_stderr(tb)
if not is_error:
return
if not self.params.get('ignoreerrors', False):
if sys.exc_info()[0] and hasattr(sys.exc_info()[1], 'exc_info') and sys.exc_info()[1].exc_info[0]:
exc_info = sys.exc_info()[1].exc_info
@ -612,11 +628,18 @@ class YoutubeDL(object):
raise DownloadError(message, exc_info)
self._download_retcode = 1
def report_warning(self, message):
def report_warning(self, message, only_once=False, _cache={}):
'''
Print the message to stderr, it will be prefixed with 'WARNING:'
If stderr is a tty file the 'WARNING:' will be colored
'''
if only_once:
m_hash = hash((self, message))
m_cnt = _cache.setdefault(m_hash, 0)
_cache[m_hash] = m_cnt + 1
if m_cnt > 0:
return
if self.params.get('logger') is not None:
self.params['logger'].warning(message)
else:
@ -629,7 +652,7 @@ class YoutubeDL(object):
warning_message = '%s %s' % (_msg_header, message)
self.to_stderr(warning_message)
def report_error(self, message, tb=None):
def report_error(self, message, *args, **kwargs):
'''
Do the same as trouble, but prefixes the message with 'ERROR:', colored
in red if stderr is a tty file.
@ -638,8 +661,18 @@ class YoutubeDL(object):
_msg_header = '\033[0;31mERROR:\033[0m'
else:
_msg_header = 'ERROR:'
error_message = '%s %s' % (_msg_header, message)
self.trouble(error_message, tb)
kwargs['message'] = '%s %s' % (_msg_header, message)
self.trouble(*args, **kwargs)
def report_unscoped_cookies(self, *args, **kwargs):
# message=None, tb=False, is_error=False
if len(args) <= 2:
kwargs.setdefault('is_error', False)
if len(args) <= 0:
kwargs.setdefault(
'message',
'Unscoped cookies are not allowed: please specify some sort of scoping')
self.report_error(*args, **kwargs)
def report_file_already_downloaded(self, file_name):
"""Report file has already been fully downloaded."""
@ -835,7 +868,7 @@ class YoutubeDL(object):
msg += '\nYou might want to use a VPN or a proxy server (with --proxy) to workaround.'
self.report_error(msg)
except ExtractorError as e: # An error we somewhat expected
self.report_error(compat_str(e), e.format_traceback())
self.report_error(compat_str(e), tb=e.format_traceback())
except MaxDownloadsReached:
raise
except Exception as e:
@ -845,8 +878,83 @@ class YoutubeDL(object):
raise
return wrapper
def _remove_cookie_header(self, http_headers):
"""Filters out `Cookie` header from an `http_headers` dict
The `Cookie` header is removed to prevent leaks as a result of unscoped cookies.
See: https://github.com/yt-dlp/yt-dlp/security/advisories/GHSA-v8mc-9377-rwjj
@param http_headers An `http_headers` dict from which any `Cookie` header
should be removed, or None
"""
return dict(filter(lambda pair: pair[0].lower() != 'cookie', (http_headers or {}).items()))
def _load_cookies(self, data, **kwargs):
"""Loads cookies from a `Cookie` header
This tries to work around the security vulnerability of passing cookies to every domain.
@param data The Cookie header as a string to load the cookies from
@param autoscope If `False`, scope cookies using Set-Cookie syntax and error for cookie without domains
If `True`, save cookies for later to be stored in the jar with a limited scope
If a URL, save cookies in the jar with the domain of the URL
"""
# autoscope=True (kw-only)
autoscope = kwargs.get('autoscope', True)
for cookie in compat_http_cookies_SimpleCookie(data).values() if data else []:
if autoscope and any(cookie.values()):
raise ValueError('Invalid syntax in Cookie Header')
domain = cookie.get('domain') or ''
expiry = cookie.get('expires')
if expiry == '': # 0 is valid so we check for `''` explicitly
expiry = None
prepared_cookie = compat_http_cookiejar_Cookie(
cookie.get('version') or 0, cookie.key, cookie.value, None, False,
domain, True, True, cookie.get('path') or '', bool(cookie.get('path')),
bool(cookie.get('secure')), expiry, False, None, None, {})
if domain:
self.cookiejar.set_cookie(prepared_cookie)
elif autoscope is True:
self.report_warning(
'Passing cookies as a header is a potential security risk; '
'they will be scoped to the domain of the downloaded urls. '
'Please consider loading cookies from a file or browser instead.',
only_once=True)
self._header_cookies.append(prepared_cookie)
elif autoscope:
self.report_warning(
'The extractor result contains an unscoped cookie as an HTTP header. '
'If you are specifying an input URL, ' + bug_reports_message(),
only_once=True)
self._apply_header_cookies(autoscope, [prepared_cookie])
else:
self.report_unscoped_cookies()
def _load_cookies_from_headers(self, headers):
self._load_cookies(traverse_obj(headers, 'cookie', casesense=False))
def _apply_header_cookies(self, url, cookies=None):
"""This method applies stray header cookies to the provided url
This loads header cookies and scopes them to the domain provided in `url`.
While this is not ideal, it helps reduce the risk of them being sent to
an unintended destination.
"""
parsed = compat_urllib_parse.urlparse(url)
if not parsed.hostname:
return
for cookie in map(copy.copy, cookies or self._header_cookies):
cookie.domain = '.' + parsed.hostname
self.cookiejar.set_cookie(cookie)
@__handle_extraction_exceptions
def __extract_info(self, url, ie, download, extra_info, process):
# Compat with passing cookies in http headers
self._apply_header_cookies(url)
ie_result = ie.extract(url)
if ie_result is None: # Finished already (backwards compatibility; listformats and friends should be moved here)
return
@ -1443,23 +1551,45 @@ class YoutubeDL(object):
parsed_selector = _parse_format_selection(iter(TokenIterator(tokens)))
return _build_selector_function(parsed_selector)
def _calc_headers(self, info_dict):
res = std_headers.copy()
def _calc_headers(self, info_dict, load_cookies=False):
if load_cookies: # For --load-info-json
# load cookies from http_headers in legacy info.json
self._load_cookies(traverse_obj(info_dict, ('http_headers', 'Cookie'), casesense=False),
autoscope=info_dict['url'])
# load scoped cookies from info.json
self._load_cookies(info_dict.get('cookies'), autoscope=False)
add_headers = info_dict.get('http_headers')
if add_headers:
res.update(add_headers)
cookies = self._calc_cookies(info_dict)
cookies = self.cookiejar.get_cookies_for_url(info_dict['url'])
if cookies:
res['Cookie'] = cookies
# Make a string like name1=val1; attr1=a_val1; ...name2=val2; ...
# By convention a cookie name can't be a well-known attribute name
# so this syntax is unambiguous and can be parsed by (eg) SimpleCookie
encoder = compat_http_cookies_SimpleCookie()
values = []
attributes = (('Domain', '='), ('Path', '='), ('Secure',), ('Expires', '='), ('Version', '='))
attributes = tuple([x[0].lower()] + list(x) for x in attributes)
for cookie in cookies:
_, value = encoder.value_encode(cookie.value)
# Py 2 '' --> '', Py 3 '' --> '""'
if value == '':
value = '""'
values.append('='.join((cookie.name, value)))
for attr in attributes:
value = getattr(cookie, attr[0], None)
if value:
values.append('%s%s' % (''.join(attr[1:]), value if len(attr) == 3 else ''))
info_dict['cookies'] = '; '.join(values)
res = std_headers.copy()
res.update(info_dict.get('http_headers') or {})
res = self._remove_cookie_header(res)
if 'X-Forwarded-For' not in res:
x_forwarded_for_ip = info_dict.get('__x_forwarded_for_ip')
if x_forwarded_for_ip:
res['X-Forwarded-For'] = x_forwarded_for_ip
return res
return res or None
def _calc_cookies(self, info_dict):
pr = sanitized_Request(info_dict['url'])
@ -1638,10 +1768,13 @@ class YoutubeDL(object):
format['protocol'] = determine_protocol(format)
# Add HTTP headers, so that external programs can use them from the
# json output
full_format_info = info_dict.copy()
full_format_info.update(format)
format['http_headers'] = self._calc_headers(full_format_info)
# Remove private housekeeping stuff
format['http_headers'] = self._calc_headers(ChainMap(format, info_dict), load_cookies=True)
# Safeguard against old/insecure infojson when using --load-info-json
info_dict['http_headers'] = self._remove_cookie_header(
info_dict.get('http_headers') or {}) or None
# Remove private housekeeping stuff (copied to http_headers in _calc_headers())
if '__x_forwarded_for_ip' in info_dict:
del info_dict['__x_forwarded_for_ip']
@ -1902,17 +2035,9 @@ class YoutubeDL(object):
(sub_lang, error_to_compat_str(err)))
continue
if self.params.get('writeinfojson', False):
infofn = replace_extension(filename, 'info.json', info_dict.get('ext'))
if self.params.get('nooverwrites', False) and os.path.exists(encodeFilename(infofn)):
self.to_screen('[info] Video description metadata is already present')
else:
self.to_screen('[info] Writing video description metadata as JSON to: ' + infofn)
try:
write_json_file(self.filter_requested_info(info_dict), infofn)
except (OSError, IOError):
self.report_error('Cannot write metadata to JSON file ' + infofn)
return
self._write_info_json(
'video description', info_dict,
replace_extension(filename, 'info.json', info_dict.get('ext')))
self._write_thumbnails(info_dict, filename)
@ -1933,7 +2058,11 @@ class YoutubeDL(object):
fd.add_progress_hook(ph)
if self.params.get('verbose'):
self.to_screen('[debug] Invoking downloader on %r' % info.get('url'))
return fd.download(name, info)
new_info = dict((k, v) for k, v in info.items() if not k.startswith('__p'))
new_info['http_headers'] = self._calc_headers(new_info)
return fd.download(name, new_info)
if info_dict.get('requested_formats') is not None:
downloaded = []
@ -2378,10 +2507,12 @@ class YoutubeDL(object):
self.get_encoding()))
write_string(encoding_str, encoding=None)
self._write_string('[debug] youtube-dl version ' + __version__ + (' (single file build)\n' if ytdl_is_updateable() else '\n'))
writeln_debug = lambda *s: self._write_string('[debug] %s\n' % (''.join(s), ))
writeln_debug('youtube-dl version ', __version__)
if _LAZY_LOADER:
self._write_string('[debug] Lazy loading extractors enabled\n')
writeln_debug = lambda *s: self._write_string('[debug] %s\n' % (''.join(s), )) # moved down for easier merge
writeln_debug('Lazy loading extractors enabled')
if ytdl_is_updateable():
writeln_debug('Single file build')
try:
sp = subprocess.Popen(
['git', 'rev-parse', '--short', 'HEAD'],
@ -2457,7 +2588,7 @@ class YoutubeDL(object):
opts_proxy = self.params.get('proxy')
if opts_cookiefile is None:
self.cookiejar = compat_cookiejar.CookieJar()
self.cookiejar = YoutubeDLCookieJar()
else:
opts_cookiefile = expand_path(opts_cookiefile)
self.cookiejar = YoutubeDLCookieJar(opts_cookiefile)
@ -2518,6 +2649,28 @@ class YoutubeDL(object):
encoding = preferredencoding()
return encoding
def _write_info_json(self, label, info_dict, infofn, overwrite=None):
if not self.params.get('writeinfojson', False):
return False
def msg(fmt, lbl):
return fmt % (lbl + ' metadata',)
if overwrite is None:
overwrite = not self.params.get('nooverwrites', False)
if not overwrite and os.path.exists(encodeFilename(infofn)):
self.to_screen(msg('[info] %s is already present', label.title()))
return 'exists'
else:
self.to_screen(msg('[info] Writing %s as JSON to: ' + infofn, label))
try:
write_json_file(self.filter_requested_info(info_dict), infofn)
return True
except (OSError, IOError):
self.report_error(msg('Cannot write %s to JSON file ' + infofn, label))
return
def _write_thumbnails(self, info_dict, filename):
if self.params.get('writethumbnail', False):
thumbnails = info_dict.get('thumbnails')

View File

@ -21,6 +21,7 @@ import socket
import struct
import subprocess
import sys
import types
import xml.etree.ElementTree
# naming convention
@ -55,6 +56,22 @@ try:
except ImportError: # Python 2
import urllib2 as compat_urllib_request
# Also fix up lack of method arg in old Pythons
try:
_req = compat_urllib_request.Request
_req('http://127.0.0.1', method='GET')
except TypeError:
class _request(object):
def __new__(cls, url, *args, **kwargs):
method = kwargs.pop('method', None)
r = _req(url, *args, **kwargs)
if method:
r.get_method = types.MethodType(lambda _: method, r)
return r
compat_urllib_request.Request = _request
try:
import urllib.error as compat_urllib_error
except ImportError: # Python 2
@ -79,6 +96,12 @@ try:
except ImportError: # Python 2
import urllib as compat_urllib_response
try:
compat_urllib_response.addinfourl.status
except AttributeError:
# .getcode() is deprecated in Py 3.
compat_urllib_response.addinfourl.status = property(lambda self: self.getcode())
try:
import http.cookiejar as compat_cookiejar
except ImportError: # Python 2
@ -103,12 +126,24 @@ except ImportError: # Python 2
import Cookie as compat_cookies
compat_http_cookies = compat_cookies
if sys.version_info[0] == 2:
if sys.version_info[0] == 2 or sys.version_info < (3, 3):
class compat_cookies_SimpleCookie(compat_cookies.SimpleCookie):
def load(self, rawdata):
must_have_value = 0
if not isinstance(rawdata, dict):
if sys.version_info[:2] != (2, 7):
# attribute must have value for parsing
rawdata, must_have_value = re.subn(
r'(?i)(;\s*)(secure|httponly)(\s*(?:;|$))', r'\1\2=\2\3', rawdata)
if sys.version_info[0] == 2:
if isinstance(rawdata, compat_str):
rawdata = str(rawdata)
return super(compat_cookies_SimpleCookie, self).load(rawdata)
super(compat_cookies_SimpleCookie, self).load(rawdata)
if must_have_value > 0:
for morsel in self.values():
for attr in ('secure', 'httponly'):
if morsel.get(attr):
morsel[attr] = True
else:
compat_cookies_SimpleCookie = compat_cookies.SimpleCookie
compat_http_cookies_SimpleCookie = compat_cookies_SimpleCookie
@ -2360,6 +2395,11 @@ try:
import http.client as compat_http_client
except ImportError: # Python 2
import httplib as compat_http_client
try:
compat_http_client.HTTPResponse.getcode
except AttributeError:
# Py < 3.1
compat_http_client.HTTPResponse.getcode = lambda self: self.status
try:
from urllib.error import HTTPError as compat_HTTPError

View File

@ -339,6 +339,10 @@ class FileDownloader(object):
def download(self, filename, info_dict):
"""Download to a filename using the info from info_dict
Return True on success and False otherwise
This method filters the `Cookie` header from the info_dict to prevent leaks.
Downloaders have their own way of handling cookies.
See: https://github.com/yt-dlp/yt-dlp/security/advisories/GHSA-v8mc-9377-rwjj
"""
nooverwrites_and_exists = (

View File

@ -1,9 +1,10 @@
from __future__ import unicode_literals
import os.path
import os
import re
import subprocess
import sys
import tempfile
import time
from .common import FileDownloader
@ -23,6 +24,8 @@ from ..utils import (
check_executable,
is_outdated_version,
process_communicate_or_kill,
T,
traverse_obj,
)
@ -30,6 +33,7 @@ class ExternalFD(FileDownloader):
def real_download(self, filename, info_dict):
self.report_destination(filename)
tmpfilename = self.temp_name(filename)
self._cookies_tempfile = None
try:
started = time.time()
@ -42,6 +46,13 @@ class ExternalFD(FileDownloader):
# should take place
retval = 0
self.to_screen('[%s] Interrupted by user' % self.get_basename())
finally:
if self._cookies_tempfile and os.path.isfile(self._cookies_tempfile):
try:
os.remove(self._cookies_tempfile)
except OSError:
self.report_warning(
'Unable to delete temporary cookies file "{0}"'.format(self._cookies_tempfile))
if retval == 0:
status = {
@ -97,6 +108,16 @@ class ExternalFD(FileDownloader):
def _configuration_args(self, default=[]):
return cli_configuration_args(self.params, 'external_downloader_args', default)
def _write_cookies(self):
if not self.ydl.cookiejar.filename:
tmp_cookies = tempfile.NamedTemporaryFile(suffix='.cookies', delete=False)
tmp_cookies.close()
self._cookies_tempfile = tmp_cookies.name
self.to_screen('[download] Writing temporary cookies file to "{0}"'.format(self._cookies_tempfile))
# real_download resets _cookies_tempfile; if it's None, save() will write to cookiejar.filename
self.ydl.cookiejar.save(self._cookies_tempfile, ignore_discard=True, ignore_expires=True)
return self.ydl.cookiejar.filename or self._cookies_tempfile
def _call_downloader(self, tmpfilename, info_dict):
""" Either overwrite this or implement _make_cmd """
cmd = [encodeArgument(a) for a in self._make_cmd(tmpfilename, info_dict)]
@ -110,13 +131,21 @@ class ExternalFD(FileDownloader):
self.to_stderr(stderr.decode('utf-8', 'replace'))
return p.returncode
@staticmethod
def _header_items(info_dict):
return traverse_obj(
info_dict, ('http_headers', T(dict.items), Ellipsis))
class CurlFD(ExternalFD):
AVAILABLE_OPT = '-V'
def _make_cmd(self, tmpfilename, info_dict):
cmd = [self.exe, '--location', '-o', tmpfilename]
for key, val in info_dict['http_headers'].items():
cmd = [self.exe, '--location', '-o', tmpfilename, '--compressed']
cookie_header = self.ydl.cookiejar.get_cookie_header(info_dict['url'])
if cookie_header:
cmd += ['--cookie', cookie_header]
for key, val in self._header_items(info_dict):
cmd += ['--header', '%s: %s' % (key, val)]
cmd += self._bool_option('--continue-at', 'continuedl', '-', '0')
cmd += self._valueless_option('--silent', 'noprogress')
@ -151,8 +180,11 @@ class AxelFD(ExternalFD):
def _make_cmd(self, tmpfilename, info_dict):
cmd = [self.exe, '-o', tmpfilename]
for key, val in info_dict['http_headers'].items():
for key, val in self._header_items(info_dict):
cmd += ['-H', '%s: %s' % (key, val)]
cookie_header = self.ydl.cookiejar.get_cookie_header(info_dict['url'])
if cookie_header:
cmd += ['-H', 'Cookie: {0}'.format(cookie_header), '--max-redirect=0']
cmd += self._configuration_args()
cmd += ['--', info_dict['url']]
return cmd
@ -162,8 +194,10 @@ class WgetFD(ExternalFD):
AVAILABLE_OPT = '--version'
def _make_cmd(self, tmpfilename, info_dict):
cmd = [self.exe, '-O', tmpfilename, '-nv', '--no-cookies']
for key, val in info_dict['http_headers'].items():
cmd = [self.exe, '-O', tmpfilename, '-nv', '--compression=auto']
if self.ydl.cookiejar.get_cookie_header(info_dict['url']):
cmd += ['--load-cookies', self._write_cookies()]
for key, val in self._header_items(info_dict):
cmd += ['--header', '%s: %s' % (key, val)]
cmd += self._option('--limit-rate', 'ratelimit')
retry = self._option('--tries', 'retries')
@ -182,20 +216,57 @@ class WgetFD(ExternalFD):
class Aria2cFD(ExternalFD):
AVAILABLE_OPT = '-v'
@staticmethod
def _aria2c_filename(fn):
return fn if os.path.isabs(fn) else os.path.join('.', fn)
def _make_cmd(self, tmpfilename, info_dict):
cmd = [self.exe, '-c']
cmd += self._configuration_args([
'--min-split-size', '1M', '--max-connection-per-server', '4'])
dn = os.path.dirname(tmpfilename)
if dn:
cmd += ['--dir', dn]
cmd += ['--out', os.path.basename(tmpfilename)]
for key, val in info_dict['http_headers'].items():
cmd = [self.exe, '-c',
'--console-log-level=warn', '--summary-interval=0', '--download-result=hide',
'--http-accept-gzip=true', '--file-allocation=none', '-x16', '-j16', '-s16']
if 'fragments' in info_dict:
cmd += ['--allow-overwrite=true', '--allow-piece-length-change=true']
else:
cmd += ['--min-split-size', '1M']
if self.ydl.cookiejar.get_cookie_header(info_dict['url']):
cmd += ['--load-cookies={0}'.format(self._write_cookies())]
for key, val in self._header_items(info_dict):
cmd += ['--header', '%s: %s' % (key, val)]
cmd += self._configuration_args(['--max-connection-per-server', '4'])
cmd += ['--out', os.path.basename(tmpfilename)]
cmd += self._option('--max-overall-download-limit', 'ratelimit')
cmd += self._option('--interface', 'source_address')
cmd += self._option('--all-proxy', 'proxy')
cmd += self._bool_option('--check-certificate', 'nocheckcertificate', 'false', 'true', '=')
cmd += self._bool_option('--remote-time', 'updatetime', 'true', 'false', '=')
cmd += self._bool_option('--show-console-readout', 'noprogress', 'false', 'true', '=')
cmd += self._configuration_args()
# aria2c strips out spaces from the beginning/end of filenames and paths.
# We work around this issue by adding a "./" to the beginning of the
# filename and relative path, and adding a "/" at the end of the path.
# See: https://github.com/yt-dlp/yt-dlp/issues/276
# https://github.com/ytdl-org/youtube-dl/issues/20312
# https://github.com/aria2/aria2/issues/1373
dn = os.path.dirname(tmpfilename)
if dn:
cmd += ['--dir', self._aria2c_filename(dn) + os.path.sep]
if 'fragments' not in info_dict:
cmd += ['--out', self._aria2c_filename(os.path.basename(tmpfilename))]
cmd += ['--auto-file-renaming=false']
if 'fragments' in info_dict:
cmd += ['--file-allocation=none', '--uri-selector=inorder']
url_list_file = '%s.frag.urls' % (tmpfilename, )
url_list = []
for frag_index, fragment in enumerate(info_dict['fragments']):
fragment_filename = '%s-Frag%d' % (os.path.basename(tmpfilename), frag_index)
url_list.append('%s\n\tout=%s' % (fragment['url'], self._aria2c_filename(fragment_filename)))
stream, _ = self.sanitize_open(url_list_file, 'wb')
stream.write('\n'.join(url_list).encode())
stream.close()
cmd += ['-i', self._aria2c_filename(url_list_file)]
else:
cmd += ['--', info_dict['url']]
return cmd
@ -235,8 +306,10 @@ class Aria2pFD(ExternalFD):
}
options['dir'] = os.path.dirname(tmpfilename) or os.path.abspath('.')
options['out'] = os.path.basename(tmpfilename)
if self.ydl.cookiejar.get_cookie_header(info_dict['url']):
options['load-cookies'] = self._write_cookies()
options['header'] = []
for key, val in info_dict['http_headers'].items():
for key, val in self._header_items(info_dict):
options['header'].append('{0}: {1}'.format(key, val))
download = aria2.add_uris([info_dict['url']], options)
status = {
@ -265,8 +338,16 @@ class HttpieFD(ExternalFD):
def _make_cmd(self, tmpfilename, info_dict):
cmd = ['http', '--download', '--output', tmpfilename, info_dict['url']]
for key, val in info_dict['http_headers'].items():
for key, val in self._header_items(info_dict):
cmd += ['%s:%s' % (key, val)]
# httpie 3.1.0+ removes the Cookie header on redirect, so this should be safe for now. [1]
# If we ever need cookie handling for redirects, we can export the cookiejar into a session. [2]
# 1: https://github.com/httpie/httpie/security/advisories/GHSA-9w4w-cpc8-h2fq
# 2: https://httpie.io/docs/cli/sessions
cookie_header = self.ydl.cookiejar.get_cookie_header(info_dict['url'])
if cookie_header:
cmd += ['Cookie:%s' % cookie_header]
return cmd
@ -312,7 +393,14 @@ class FFmpegFD(ExternalFD):
# if end_time:
# args += ['-t', compat_str(end_time - start_time)]
if info_dict['http_headers'] and re.match(r'^https?://', url):
cookies = self.ydl.cookiejar.get_cookies_for_url(url)
if cookies:
args.extend(['-cookies', ''.join(
'{0}={1}; path={2}; domain={3};\r\n'.format(
cookie.name, cookie.value, cookie.path, cookie.domain)
for cookie in cookies)])
if info_dict.get('http_headers') and re.match(r'^https?://', url):
# Trailing \r\n after each HTTP header is important to prevent warning from ffmpeg/avconv:
# [http @ 00000000003d2fa0] No trailing CRLF found in HTTP header.
headers = handle_youtubedl_headers(info_dict['http_headers'])

View File

@ -280,16 +280,16 @@ class JSInterpreter(object):
# make Py 2.6 conform to its lying documentation
if name == 'flags':
self.flags = self.__flags
return self.flags
elif name == 'pattern':
self.pattern = self.__pattern_txt
return self.pattern
elif hasattr(self.__self, name):
v = getattr(self.__self, name)
setattr(self, name, v)
return v
elif name in ('groupindex', 'groups'):
# in case these get set after a match?
if hasattr(self.__self, name):
setattr(self, name, getattr(self.__self, name))
else:
return 0 if name == 'groupindex' else {}
if hasattr(self, name):
return getattr(self, name)
raise AttributeError('{0} has no attribute named {1}'.format(self, name))
@classmethod

View File

@ -544,12 +544,14 @@ def parseOpts(overrideArguments=None):
workarounds.add_option(
'--referer',
metavar='URL', dest='referer', default=None,
help='Specify a custom referer, use if the video access is restricted to one domain',
help='Specify a custom Referer: use if the video access is restricted to one domain',
)
workarounds.add_option(
'--add-header',
metavar='FIELD:VALUE', dest='headers', action='append',
help='Specify a custom HTTP header and its value, separated by a colon \':\'. You can use this option multiple times',
help=('Specify a custom HTTP header and its value, separated by a colon \':\'. You can use this option multiple times. '
'NB Use --cookies rather than adding a Cookie header if its contents may be sensitive; '
'data from a Cookie header will be sent to all domains, not just the one intended')
)
workarounds.add_option(
'--bidi-workaround',

View File

@ -16,6 +16,7 @@ import email.header
import errno
import functools
import gzip
import inspect
import io
import itertools
import json
@ -40,7 +41,6 @@ import zlib
from .compat import (
compat_HTMLParseError,
compat_HTMLParser,
compat_HTTPError,
compat_basestring,
compat_casefold,
compat_chr,
@ -63,6 +63,7 @@ from .compat import (
compat_struct_pack,
compat_struct_unpack,
compat_urllib_error,
compat_urllib_HTTPError,
compat_urllib_parse,
compat_urllib_parse_parse_qs as compat_parse_qs,
compat_urllib_parse_urlencode,
@ -2613,7 +2614,8 @@ class YoutubeDLHandler(compat_urllib_request.HTTPHandler):
Part of this code was copied from:
http://techknack.net/python-urllib2-handlers/
http://techknack.net/python-urllib2-handlers/, archived at
https://web.archive.org/web/20130527205558/http://techknack.net/python-urllib2-handlers/
Andrew Rowls, the author of that code, agreed to release it to the
public domain.
@ -2671,7 +2673,9 @@ class YoutubeDLHandler(compat_urllib_request.HTTPHandler):
req._Request__original = req._Request__original.partition('#')[0]
req._Request__r_type = req._Request__r_type.partition('#')[0]
return req
# Use the totally undocumented AbstractHTTPHandler per
# https://github.com/yt-dlp/yt-dlp/pull/4158
return compat_urllib_request.AbstractHTTPHandler.do_request_(self, req)
def http_response(self, req, resp):
old_resp = resp
@ -2682,7 +2686,7 @@ class YoutubeDLHandler(compat_urllib_request.HTTPHandler):
try:
uncompressed = io.BytesIO(gz.read())
except IOError as original_ioerror:
# There may be junk add the end of the file
# There may be junk at the end of the file
# See http://stackoverflow.com/q/4928560/35070 for details
for i in range(1, 1024):
try:
@ -2709,8 +2713,7 @@ class YoutubeDLHandler(compat_urllib_request.HTTPHandler):
if location:
# As of RFC 2616 default charset is iso-8859-1 that is respected by python 3
if sys.version_info >= (3, 0):
location = location.encode('iso-8859-1').decode('utf-8')
else:
location = location.encode('iso-8859-1')
location = location.decode('utf-8')
location_escaped = escape_url(location)
if location != location_escaped:
@ -2909,6 +2912,19 @@ class YoutubeDLCookieJar(compat_cookiejar.MozillaCookieJar):
cookie.expires = None
cookie.discard = True
def get_cookie_header(self, url):
"""Generate a Cookie HTTP header for a given url"""
cookie_req = sanitized_Request(url)
self.add_cookie_header(cookie_req)
return cookie_req.get_header('Cookie')
def get_cookies_for_url(self, url):
"""Generate a list of Cookie objects for a given url"""
# Policy `_now` attribute must be set before calling `_cookies_for_request`
# Ref: https://github.com/python/cpython/blob/3.7/Lib/http/cookiejar.py#L1360
self._policy._now = self._now = int(time.time())
return self._cookies_for_request(sanitized_Request(url))
class YoutubeDLCookieProcessor(compat_urllib_request.HTTPCookieProcessor):
def __init__(self, cookiejar=None):
@ -2939,17 +2955,16 @@ class YoutubeDLRedirectHandler(compat_urllib_request.HTTPRedirectHandler):
The code is based on HTTPRedirectHandler implementation from CPython [1].
This redirect handler solves two issues:
- ensures redirect URL is always unicode under python 2
- introduces support for experimental HTTP response status code
308 Permanent Redirect [2] used by some sites [3]
This redirect handler fixes and improves the logic to better align with RFC7261
and what browsers tend to do [2][3]
1. https://github.com/python/cpython/blob/master/Lib/urllib/request.py
2. https://developer.mozilla.org/en-US/docs/Web/HTTP/Status/308
3. https://github.com/ytdl-org/youtube-dl/issues/28768
2. https://datatracker.ietf.org/doc/html/rfc7231
3. https://github.com/python/cpython/issues/91306
"""
http_error_301 = http_error_303 = http_error_307 = http_error_308 = compat_urllib_request.HTTPRedirectHandler.http_error_302
# Supply possibly missing alias
http_error_308 = compat_urllib_request.HTTPRedirectHandler.http_error_302
def redirect_request(self, req, fp, code, msg, headers, newurl):
"""Return a Request or None in response to a redirect.
@ -2961,19 +2976,15 @@ class YoutubeDLRedirectHandler(compat_urllib_request.HTTPRedirectHandler):
else should try to handle this url. Return None if you can't
but another Handler might.
"""
m = req.get_method()
if (not (code in (301, 302, 303, 307, 308) and m in ("GET", "HEAD")
or code in (301, 302, 303) and m == "POST")):
raise compat_HTTPError(req.full_url, code, msg, headers, fp)
# Strictly (according to RFC 2616), 301 or 302 in response to
# a POST MUST NOT cause a redirection without confirmation
# from the user (of urllib.request, in this case). In practice,
# essentially all clients do redirect in this case, so we do
# the same.
if code not in (301, 302, 303, 307, 308):
raise compat_urllib_HTTPError(req.full_url, code, msg, headers, fp)
new_method = req.get_method()
new_data = req.data
# On python 2 urlh.geturl() may sometimes return redirect URL
# as byte string instead of unicode. This workaround allows
# to force it always return unicode.
# as a byte string instead of unicode. This workaround forces
# it to return unicode.
if sys.version_info[0] < 3:
newurl = compat_str(newurl)
@ -2982,13 +2993,34 @@ class YoutubeDLRedirectHandler(compat_urllib_request.HTTPRedirectHandler):
# but it is kept for compatibility with other callers.
newurl = newurl.replace(' ', '%20')
CONTENT_HEADERS = ("content-length", "content-type")
# Technically the Cookie header should be in unredirected_hdrs;
# however in practice some may set it in normal headers anyway.
# We will remove it here to prevent any leaks.
remove_headers = ['Cookie']
# A 303 must either use GET or HEAD for subsequent request
# https://datatracker.ietf.org/doc/html/rfc7231#section-6.4.4
if code == 303 and req.get_method() != 'HEAD':
new_method = 'GET'
# 301 and 302 redirects are commonly turned into a GET from a POST
# for subsequent requests by browsers, so we'll do the same.
# https://datatracker.ietf.org/doc/html/rfc7231#section-6.4.2
# https://datatracker.ietf.org/doc/html/rfc7231#section-6.4.3
elif code in (301, 302) and req.get_method() == 'POST':
new_method = 'GET'
# only remove payload if method changed (e.g. POST to GET)
if new_method != req.get_method():
new_data = None
remove_headers.extend(['Content-Length', 'Content-Type'])
# NB: don't use dict comprehension for python 2.6 compatibility
newheaders = dict((k, v) for k, v in req.headers.items()
if k.lower() not in CONTENT_HEADERS)
new_headers = dict((k, v) for k, v in req.header_items()
if k.title() not in remove_headers)
return compat_urllib_request.Request(
newurl, headers=newheaders, origin_req_host=req.origin_req_host,
unverifiable=True)
newurl, headers=new_headers, origin_req_host=req.origin_req_host,
unverifiable=True, method=new_method, data=new_data)
def extract_timezone(date_str):
@ -3881,7 +3913,7 @@ def detect_exe_version(output, version_re=None, unrecognized='present'):
return unrecognized
class LazyList(compat_collections_abc.Sequence):
class LazyList(compat_collections_abc.Iterable):
"""Lazy immutable list from an iterable
Note that slices of a LazyList are lists and not LazyList"""
@ -4223,10 +4255,16 @@ def multipart_encode(data, boundary=None):
return out, content_type
def variadic(x, allowed_types=(compat_str, bytes, dict)):
if not isinstance(allowed_types, tuple) and isinstance(allowed_types, compat_collections_abc.Iterable):
def is_iterable_like(x, allowed_types=compat_collections_abc.Iterable, blocked_types=NO_DEFAULT):
if blocked_types is NO_DEFAULT:
blocked_types = (compat_str, bytes, compat_collections_abc.Mapping)
return isinstance(x, allowed_types) and not isinstance(x, blocked_types)
def variadic(x, allowed_types=NO_DEFAULT):
if isinstance(allowed_types, compat_collections_abc.Iterable):
allowed_types = tuple(allowed_types)
return x if isinstance(x, compat_collections_abc.Iterable) and not isinstance(x, allowed_types) else (x,)
return x if is_iterable_like(x, blocked_types=allowed_types) else (x,)
def dict_get(d, key_or_keys, default=None, skip_false_values=True):
@ -5993,7 +6031,7 @@ def clean_podcast_url(url):
def traverse_obj(obj, *paths, **kwargs):
"""
Safely traverse nested `dict`s and `Sequence`s
Safely traverse nested `dict`s and `Iterable`s
>>> obj = [{}, {"key": "value"}]
>>> traverse_obj(obj, (1, "key"))
@ -6001,14 +6039,17 @@ def traverse_obj(obj, *paths, **kwargs):
Each of the provided `paths` is tested and the first producing a valid result will be returned.
The next path will also be tested if the path branched but no results could be found.
Supported values for traversal are `Mapping`, `Sequence` and `re.Match`.
A value of None is treated as the absence of a value.
Supported values for traversal are `Mapping`, `Iterable` and `re.Match`.
Unhelpful values (`{}`, `None`) are treated as the absence of a value and discarded.
The paths will be wrapped in `variadic`, so that `'key'` is conveniently the same as `('key', )`.
The keys in the path can be one of:
- `None`: Return the current object.
- `str`/`int`: Return `obj[key]`. For `re.Match, return `obj.group(key)`.
- `set`: Requires the only item in the set to be a type or function,
like `{type}`/`{func}`. If a `type`, returns only values
of this type. If a function, returns `func(obj)`.
- `str`/`int`: Return `obj[key]`. For `re.Match`, return `obj.group(key)`.
- `slice`: Branch out and return all values in `obj[key]`.
- `Ellipsis`: Branch out and return a list of all values.
- `tuple`/`list`: Branch out and return a list of all matching values.
@ -6016,6 +6057,9 @@ def traverse_obj(obj, *paths, **kwargs):
- `function`: Branch out and return values filtered by the function.
Read as: `[value for key, value in obj if function(key, value)]`.
For `Sequence`s, `key` is the index of the value.
For `Iterable`s, `key` is the enumeration count of the value.
For `re.Match`es, `key` is the group number (0 = full match)
as well as additionally any group names, if given.
- `dict` Transform the current object and return a matching dict.
Read as: `{key: traverse_obj(obj, path) for key, path in dct.items()}`.
@ -6024,8 +6068,12 @@ def traverse_obj(obj, *paths, **kwargs):
@params paths Paths which to traverse by.
Keyword arguments:
@param default Value to return if the paths do not match.
If the last key in the path is a `dict`, it will apply to each value inside
the dict instead, depth first. Try to avoid if using nested `dict` keys.
@param expected_type If a `type`, only accept final values of this type.
If any other callable, try to call the function on each result.
If the last key in the path is a `dict`, it will apply to each value inside
the dict instead, recursively. This does respect branching paths.
@param get_all If `False`, return the first matching result, otherwise all matching ones.
@param casesense If `False`, consider string dictionary keys as case insensitive.
@ -6036,12 +6084,15 @@ def traverse_obj(obj, *paths, **kwargs):
@param _traverse_string Whether to traverse into objects as strings.
If `True`, any non-compatible object will first be
converted into a string and then traversed into.
The return value of that path will be a string instead,
not respecting any further branching.
@returns The result of the object traversal.
If successful, `get_all=True`, and the path branches at least once,
then a list of results is returned instead.
A list is always returned if the last path branches and no `default` is given.
If a path ends on a `dict` that result will always be a `dict`.
"""
# parameter defaults
@ -6055,7 +6106,6 @@ def traverse_obj(obj, *paths, **kwargs):
# instant compat
str = compat_str
is_sequence = lambda x: isinstance(x, compat_collections_abc.Sequence) and not isinstance(x, (str, bytes))
casefold = lambda k: compat_casefold(k) if isinstance(k, str) else k
if isinstance(expected_type, type):
@ -6063,128 +6113,184 @@ def traverse_obj(obj, *paths, **kwargs):
else:
type_test = lambda val: try_call(expected_type or IDENTITY, args=(val,))
def lookup_or_none(v, k, getter=None):
try:
return getter(v, k) if getter else v[k]
except IndexError:
return None
def from_iterable(iterables):
# chain.from_iterable(['ABC', 'DEF']) --> A B C D E F
for it in iterables:
for item in it:
yield item
def apply_key(key, obj):
if obj is None:
return
def apply_key(key, obj, is_last):
branching = False
if obj is None and _traverse_string:
if key is Ellipsis or callable(key) or isinstance(key, slice):
branching = True
result = ()
else:
result = None
elif key is None:
yield obj
result = obj
elif isinstance(key, set):
assert len(key) == 1, 'Set should only be used to wrap a single item'
item = next(iter(key))
if isinstance(item, type):
result = obj if isinstance(obj, item) else None
else:
result = try_call(item, args=(obj,))
elif isinstance(key, (list, tuple)):
for branch in key:
_, result = apply_path(obj, branch)
for item in result:
yield item
branching = True
result = from_iterable(
apply_path(obj, branch, is_last)[0] for branch in key)
elif key is Ellipsis:
result = []
branching = True
if isinstance(obj, compat_collections_abc.Mapping):
result = obj.values()
elif is_sequence(obj):
elif is_iterable_like(obj):
result = obj
elif isinstance(obj, compat_re_Match):
result = obj.groups()
elif _traverse_string:
branching = False
result = str(obj)
for item in result:
yield item
else:
result = ()
elif callable(key):
if is_sequence(obj):
iter_obj = enumerate(obj)
elif isinstance(obj, compat_collections_abc.Mapping):
branching = True
if isinstance(obj, compat_collections_abc.Mapping):
iter_obj = obj.items()
elif is_iterable_like(obj):
iter_obj = enumerate(obj)
elif isinstance(obj, compat_re_Match):
iter_obj = enumerate(itertools.chain([obj.group()], obj.groups()))
iter_obj = itertools.chain(
enumerate(itertools.chain((obj.group(),), obj.groups())),
obj.groupdict().items())
elif _traverse_string:
branching = False
iter_obj = enumerate(str(obj))
else:
return
for item in (v for k, v in iter_obj if try_call(key, args=(k, v))):
yield item
iter_obj = ()
result = (v for k, v in iter_obj if try_call(key, args=(k, v)))
if not branching: # string traversal
result = ''.join(result)
elif isinstance(key, dict):
iter_obj = ((k, _traverse_obj(obj, v)) for k, v in key.items())
yield dict((k, v if v is not None else default) for k, v in iter_obj
if v is not None or default is not NO_DEFAULT)
iter_obj = ((k, _traverse_obj(obj, v, False, is_last)) for k, v in key.items())
result = dict((k, v if v is not None else default) for k, v in iter_obj
if v is not None or default is not NO_DEFAULT) or None
elif isinstance(obj, compat_collections_abc.Mapping):
yield (obj.get(key) if casesense or (key in obj)
result = (try_call(obj.get, args=(key,))
if casesense or try_call(obj.__contains__, args=(key,))
else next((v for k, v in obj.items() if casefold(k) == key), None))
elif isinstance(obj, compat_re_Match):
result = None
if isinstance(key, int) or casesense:
try:
yield obj.group(key)
return
except IndexError:
pass
if not isinstance(key, str):
return
# Py 2.6 doesn't have methods in the Match class/type
result = lookup_or_none(obj, key, getter=lambda _, k: obj.group(k))
yield next((v for k, v in obj.groupdict().items() if casefold(k) == key), None)
elif isinstance(key, str):
result = next((v for k, v in obj.groupdict().items()
if casefold(k) == key), None)
else:
if _is_user_input:
key = (int_or_none(key) if ':' not in key
else slice(*map(int_or_none, key.split(':'))))
result = None
if isinstance(key, (int, slice)):
if is_iterable_like(obj, compat_collections_abc.Sequence):
branching = isinstance(key, slice)
result = lookup_or_none(obj, key)
elif _traverse_string:
result = lookup_or_none(str(obj), key)
if not isinstance(key, (int, slice)):
return branching, result if branching else (result,)
def lazy_last(iterable):
iterator = iter(iterable)
prev = next(iterator, NO_DEFAULT)
if prev is NO_DEFAULT:
return
if not is_sequence(obj):
if not _traverse_string:
return
obj = str(obj)
for item in iterator:
yield False, prev
prev = item
try:
yield obj[key]
except IndexError:
pass
yield True, prev
def apply_path(start_obj, path):
def apply_path(start_obj, path, test_type):
objs = (start_obj,)
has_branched = False
for key in variadic(path):
if _is_user_input and key == ':':
key = None
for last, key in lazy_last(variadic(path, (str, bytes, dict, set))):
if _is_user_input and isinstance(key, str):
if key == ':':
key = Ellipsis
elif ':' in key:
key = slice(*map(int_or_none, key.split(':')))
elif int_or_none(key) is not None:
key = int(key)
if not casesense and isinstance(key, str):
key = compat_casefold(key)
if key is Ellipsis or isinstance(key, (list, tuple)) or callable(key):
has_branched = True
if __debug__ and callable(key):
# Verify function signature
args = inspect.getargspec(key)
if len(args.args) != 2:
# crash differently in 2.6 !
inspect.getcallargs(key, None, None)
key_func = functools.partial(apply_key, key)
objs = from_iterable(map(key_func, objs))
new_objs = []
for obj in objs:
branching, results = apply_key(key, obj, last)
has_branched |= branching
new_objs.append(results)
return has_branched, objs
objs = from_iterable(new_objs)
def _traverse_obj(obj, path, use_list=True):
has_branched, results = apply_path(obj, path)
results = LazyList(x for x in map(type_test, results) if x is not None)
if test_type and not isinstance(key, (dict, list, tuple)):
objs = map(type_test, objs)
return objs, has_branched, isinstance(key, dict)
def _traverse_obj(obj, path, allow_empty, test_type):
results, has_branched, is_dict = apply_path(obj, path, test_type)
results = LazyList(x for x in results if x not in (None, {}))
if get_all and has_branched:
return results.exhaust() if results or use_list else None
if results:
return results.exhaust()
if allow_empty:
return [] if default is NO_DEFAULT else default
return None
return results[0] if results else None
return results[0] if results else {} if allow_empty and is_dict else None
for index, path in enumerate(paths, 1):
use_list = default is NO_DEFAULT and index == len(paths)
result = _traverse_obj(obj, path, use_list)
result = _traverse_obj(obj, path, index == len(paths), True)
if result is not None:
return result
return None if default is NO_DEFAULT else default
def T(x):
""" For use in yt-dl instead of {type} or set((type,)) """
return set((x,))
def get_first(obj, keys, **kwargs):
return traverse_obj(obj, (Ellipsis,) + tuple(variadic(keys)), get_all=False, **kwargs)