Compare commits

...

13 Commits

Author SHA1 Message Date
dirkf
47214e46d8 [compat] Fix old Pythons broken loading of valueless cookie attributes
Cookie string parsing in Py 2.6.9, probably earlier, requires `=`.
Also 3.2, though the CPython code appears to be OK: 3.1 was also wrong.
2023-07-18 10:50:46 +01:00
dirkf
1d8d5a93f7 [test] Fixes for old Pythons 2023-07-18 10:50:46 +01:00
dirkf
1634b1d61e [doc] Warn against setting cookies with --add-header 2023-07-18 10:50:46 +01:00
bashonly
21438a4194 [downloader/external] Fix cookie support 2023-07-18 10:50:46 +01:00
Simon Sawicki
8334ec961b [core] Process header cookies on loading 2023-07-18 10:50:46 +01:00
bashonly
3801d36416 [utils] YoutubeDLCookieJar: Add get_cookie_header and get_cookies_for_url methods 2023-07-18 10:50:46 +01:00
dirkf
b383be9887 [core] Remove Cookie header on redirect to prevent leaks
Adated from yt-dlp/yt-dlp-ghsa-v8mc-9377-rwjj/pull/1/commits/101caac
Thx coletdjnz
2023-07-18 10:50:46 +01:00
dirkf
46fde7caee [core] Update redirect handling from yt-dlp
* Thx coletdjnz: https://github.com/yt-dlp/yt-dlp/pull/7094
* add test that redirected `POST` loses its `Content-Type`
2023-07-18 10:50:46 +01:00
dirkf
648dc5304c [compat] Add Request and HTTPClient compat for redirect
* support `method` parameter of `Request.__init__`  (Py 2 and old Py 3)
* support `getcode` method of compat_http_client.HTTPResponse (Py 2)
2023-07-18 10:50:46 +01:00
dirkf
1720c04dc5 [test] Make skipped tests in test_execution work with Py 2.6 2023-07-18 10:50:46 +01:00
dirkf
d5ef405c5d [core] Align error reporting methods with yt-dlp 2023-07-18 10:50:46 +01:00
dirkf
f47fdb9564 [utils] Add {expected_type} and Iterable support to traverse_obj() 2023-07-18 10:50:46 +01:00
dirkf
b6dff4073d [core] Revert version display from b8a86dc 2023-07-18 10:50:46 +01:00
15 changed files with 1524 additions and 276 deletions

View File

@ -301,7 +301,7 @@ jobs:
if: ${{ matrix.python-version == '2.6' }} if: ${{ matrix.python-version == '2.6' }}
shell: bash shell: bash
run: | run: |
# see pip for Jython # Work around deprecation of support for non-SNI clients at PyPI CDN (see https://status.python.org/incidents/hzmjhqsdjqgb)
$PIP -qq show unittest2 || { \ $PIP -qq show unittest2 || { \
for u in "65/26/32b8464df2a97e6dd1b656ed26b2c194606c16fe163c695a992b36c11cdf/six-1.13.0-py2.py3-none-any.whl" \ for u in "65/26/32b8464df2a97e6dd1b656ed26b2c194606c16fe163c695a992b36c11cdf/six-1.13.0-py2.py3-none-any.whl" \
"f2/94/3af39d34be01a24a6e65433d19e107099374224905f1e0cc6bbe1fd22a2f/argparse-1.4.0-py2.py3-none-any.whl" \ "f2/94/3af39d34be01a24a6e65433d19e107099374224905f1e0cc6bbe1fd22a2f/argparse-1.4.0-py2.py3-none-any.whl" \
@ -312,7 +312,7 @@ jobs:
$PIP install ${u##*/}; \ $PIP install ${u##*/}; \
done; } done; }
# make tests use unittest2 # make tests use unittest2
for test in ./test/test_*.py; do for test in ./test/test_*.py ./test/helper.py; do
sed -r -i -e '/^import unittest$/s/test/test2 as unittest/' "$test" sed -r -i -e '/^import unittest$/s/test/test2 as unittest/' "$test"
done done
#-------- nose -------- #-------- nose --------

View File

@ -9,6 +9,7 @@ import re
import types import types
import ssl import ssl
import sys import sys
import unittest
import youtube_dl.extractor import youtube_dl.extractor
from youtube_dl import YoutubeDL from youtube_dl import YoutubeDL
@ -17,6 +18,7 @@ from youtube_dl.compat import (
compat_str, compat_str,
) )
from youtube_dl.utils import ( from youtube_dl.utils import (
IDENTITY,
preferredencoding, preferredencoding,
write_string, write_string,
) )
@ -72,7 +74,8 @@ class FakeYDL(YoutubeDL):
def to_screen(self, s, skip_eol=None): def to_screen(self, s, skip_eol=None):
print(s) print(s)
def trouble(self, s, tb=None): def trouble(self, *args, **kwargs):
s = args[0] if len(args) > 0 else kwargs.get('message', 'Missing message')
raise Exception(s) raise Exception(s)
def download(self, x): def download(self, x):
@ -297,3 +300,7 @@ def http_server_port(httpd):
else: else:
sock = httpd.socket sock = httpd.socket
return sock.getsockname()[1] return sock.getsockname()[1]
def expectedFailureIf(cond):
return unittest.expectedFailure if cond else IDENTITY

View File

@ -10,14 +10,30 @@ import unittest
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
import copy import copy
import json
from test.helper import FakeYDL, assertRegexpMatches from test.helper import (
FakeYDL,
assertRegexpMatches,
try_rm,
)
from youtube_dl import YoutubeDL from youtube_dl import YoutubeDL
from youtube_dl.compat import compat_str, compat_urllib_error from youtube_dl.compat import (
compat_http_cookiejar_Cookie,
compat_http_cookies_SimpleCookie,
compat_kwargs,
compat_str,
compat_urllib_error,
)
from youtube_dl.extractor import YoutubeIE from youtube_dl.extractor import YoutubeIE
from youtube_dl.extractor.common import InfoExtractor from youtube_dl.extractor.common import InfoExtractor
from youtube_dl.postprocessor.common import PostProcessor from youtube_dl.postprocessor.common import PostProcessor
from youtube_dl.utils import ExtractorError, match_filter_func from youtube_dl.utils import (
ExtractorError,
match_filter_func,
traverse_obj,
)
TEST_URL = 'http://localhost/sample.mp4' TEST_URL = 'http://localhost/sample.mp4'
@ -29,11 +45,14 @@ class YDL(FakeYDL):
self.msgs = [] self.msgs = []
def process_info(self, info_dict): def process_info(self, info_dict):
self.downloaded_info_dicts.append(info_dict) self.downloaded_info_dicts.append(info_dict.copy())
def to_screen(self, msg): def to_screen(self, msg):
self.msgs.append(msg) self.msgs.append(msg)
def dl(self, *args, **kwargs):
assert False, 'Downloader must not be invoked for test_YoutubeDL'
def _make_result(formats, **kwargs): def _make_result(formats, **kwargs):
res = { res = {
@ -42,8 +61,9 @@ def _make_result(formats, **kwargs):
'title': 'testttitle', 'title': 'testttitle',
'extractor': 'testex', 'extractor': 'testex',
'extractor_key': 'TestEx', 'extractor_key': 'TestEx',
'webpage_url': 'http://example.com/watch?v=shenanigans',
} }
res.update(**kwargs) res.update(**compat_kwargs(kwargs))
return res return res
@ -930,17 +950,11 @@ class TestYoutubeDL(unittest.TestCase):
# Test case for https://github.com/ytdl-org/youtube-dl/issues/27064 # Test case for https://github.com/ytdl-org/youtube-dl/issues/27064
def test_ignoreerrors_for_playlist_with_url_transparent_iterable_entries(self): def test_ignoreerrors_for_playlist_with_url_transparent_iterable_entries(self):
class _YDL(YDL): ydl = YDL({
def __init__(self, *args, **kwargs):
super(_YDL, self).__init__(*args, **kwargs)
def trouble(self, s, tb=None):
pass
ydl = _YDL({
'format': 'extra', 'format': 'extra',
'ignoreerrors': True, 'ignoreerrors': True,
}) })
ydl.trouble = lambda *_, **__: None
class VideoIE(InfoExtractor): class VideoIE(InfoExtractor):
_VALID_URL = r'video:(?P<id>\d+)' _VALID_URL = r'video:(?P<id>\d+)'
@ -1017,5 +1031,160 @@ class TestYoutubeDL(unittest.TestCase):
self.assertEqual(out_info['release_date'], '20210930') self.assertEqual(out_info['release_date'], '20210930')
class TestYoutubeDLCookies(unittest.TestCase):
@staticmethod
def encode_cookie(cookie):
if not isinstance(cookie, dict):
cookie = vars(cookie)
for name, value in cookie.items():
yield name, compat_str(value)
@classmethod
def comparable_cookies(cls, cookies):
# Work around cookiejar cookies not being unicode strings
return sorted(map(tuple, map(sorted, map(cls.encode_cookie, cookies))))
def assertSameCookies(self, c1, c2, msg=None):
return self.assertEqual(
*map(self.comparable_cookies, (c1, c2)),
msg=msg)
def assertSameCookieStrings(self, c1, c2, msg=None):
return self.assertSameCookies(
*map(lambda c: compat_http_cookies_SimpleCookie(c).values(), (c1, c2)),
msg=msg)
def test_header_cookies(self):
ydl = FakeYDL()
ydl.report_warning = lambda *_, **__: None
def cookie(name, value, version=None, domain='', path='', secure=False, expires=None):
return compat_http_cookiejar_Cookie(
version or 0, name, value, None, False,
domain, bool(domain), bool(domain), path, bool(path),
secure, expires, False, None, None, rest={})
test_url, test_domain = (t % ('yt.dl',) for t in ('https://%s/test', '.%s'))
def test(encoded_cookies, cookies, headers=False, round_trip=None, error_re=None):
def _test():
ydl.cookiejar.clear()
ydl._load_cookies(encoded_cookies, autoscope=headers)
if headers:
ydl._apply_header_cookies(test_url)
data = {'url': test_url}
ydl._calc_headers(data)
self.assertSameCookies(
cookies, ydl.cookiejar,
'Extracted cookiejar.Cookie is not the same')
if not headers:
self.assertSameCookieStrings(
data.get('cookies'), round_trip or encoded_cookies,
msg='Cookie is not the same as round trip')
ydl.__dict__['_YoutubeDL__header_cookies'] = []
try:
_test()
except AssertionError:
raise
except Exception as e:
if not error_re:
raise
assertRegexpMatches(self, e.args[0], error_re.join(('.*',) * 2))
test('test=value; Domain=' + test_domain, [cookie('test', 'value', domain=test_domain)])
test('test=value', [cookie('test', 'value')], error_re='Unscoped cookies are not allowed')
test('cookie1=value1; Domain={0}; Path=/test; cookie2=value2; Domain={0}; Path=/'.format(test_domain), [
cookie('cookie1', 'value1', domain=test_domain, path='/test'),
cookie('cookie2', 'value2', domain=test_domain, path='/')])
cookie_kw = compat_kwargs(
{'domain': test_domain, 'path': '/test', 'secure': True, 'expires': '9999999999', })
test('test=value; Domain={domain}; Path={path}; Secure; Expires={expires}'.format(**cookie_kw), [
cookie('test', 'value', **cookie_kw)])
test('test="value; "; path=/test; domain=' + test_domain, [
cookie('test', 'value; ', domain=test_domain, path='/test')],
round_trip='test="value\\073 "; Domain={0}; Path=/test'.format(test_domain))
test('name=; Domain=' + test_domain, [cookie('name', '', domain=test_domain)],
round_trip='name=""; Domain=' + test_domain)
test('test=value', [cookie('test', 'value', domain=test_domain)], headers=True)
test('cookie1=value; Domain={0}; cookie2=value'.format(test_domain), [],
headers=True, error_re='Invalid syntax')
ydl.report_warning = ydl.report_error
test('test=value', [], headers=True, error_re='Passing cookies as a header is a potential security risk')
def test_infojson_cookies(self):
TEST_FILE = 'test_infojson_cookies.info.json'
TEST_URL = 'https://example.com/example.mp4'
COOKIES = 'a=b; Domain=.example.com; c=d; Domain=.example.com'
COOKIE_HEADER = {'Cookie': 'a=b; c=d'}
ydl = FakeYDL()
ydl.process_info = lambda x: ydl._write_info_json('test', x, TEST_FILE)
def make_info(info_header_cookies=False, fmts_header_cookies=False, cookies_field=False):
fmt = {'url': TEST_URL}
if fmts_header_cookies:
fmt['http_headers'] = COOKIE_HEADER
if cookies_field:
fmt['cookies'] = COOKIES
return _make_result([fmt], http_headers=COOKIE_HEADER if info_header_cookies else None)
def test(initial_info, note):
def failure_msg(why):
return ' when '.join((why, note))
result = {}
result['processed'] = ydl.process_ie_result(initial_info)
self.assertTrue(ydl.cookiejar.get_cookies_for_url(TEST_URL),
msg=failure_msg('No cookies set in cookiejar after initial process'))
ydl.cookiejar.clear()
with open(TEST_FILE) as infojson:
result['loaded'] = ydl.sanitize_info(json.load(infojson), True)
result['final'] = ydl.process_ie_result(result['loaded'].copy(), download=False)
self.assertTrue(ydl.cookiejar.get_cookies_for_url(TEST_URL),
msg=failure_msg('No cookies set in cookiejar after final process'))
ydl.cookiejar.clear()
for key in ('processed', 'loaded', 'final'):
info = result[key]
self.assertIsNone(
traverse_obj(info, ((None, ('formats', 0)), 'http_headers', 'Cookie'), casesense=False, get_all=False),
msg=failure_msg('Cookie header not removed in {0} result'.format(key)))
self.assertSameCookieStrings(
traverse_obj(info, ((None, ('formats', 0)), 'cookies'), get_all=False), COOKIES,
msg=failure_msg('No cookies field found in {0} result'.format(key)))
test({'url': TEST_URL, 'http_headers': COOKIE_HEADER, 'id': '1', 'title': 'x'}, 'no formats field')
test(make_info(info_header_cookies=True), 'info_dict header cokies')
test(make_info(fmts_header_cookies=True), 'format header cookies')
test(make_info(info_header_cookies=True, fmts_header_cookies=True), 'info_dict and format header cookies')
test(make_info(info_header_cookies=True, fmts_header_cookies=True, cookies_field=True), 'all cookies fields')
test(make_info(cookies_field=True), 'cookies format field')
test({'url': TEST_URL, 'cookies': COOKIES, 'id': '1', 'title': 'x'}, 'info_dict cookies field only')
try_rm(TEST_FILE)
def test_add_headers_cookie(self):
def check_for_cookie_header(result):
return traverse_obj(result, ((None, ('formats', 0)), 'http_headers', 'Cookie'), casesense=False, get_all=False)
ydl = FakeYDL({'http_headers': {'Cookie': 'a=b'}})
ydl._apply_header_cookies(_make_result([])['webpage_url']) # Scope to input webpage URL: .example.com
fmt = {'url': 'https://example.com/video.mp4'}
result = ydl.process_ie_result(_make_result([fmt]), download=False)
self.assertIsNone(check_for_cookie_header(result), msg='http_headers cookies in result info_dict')
self.assertEqual(result.get('cookies'), 'a=b; Domain=.example.com', msg='No cookies were set in cookies field')
self.assertIn('a=b', ydl.cookiejar.get_cookie_header(fmt['url']), msg='No cookies were set in cookiejar')
fmt = {'url': 'https://wrong.com/video.mp4'}
result = ydl.process_ie_result(_make_result([fmt]), download=False)
self.assertIsNone(check_for_cookie_header(result), msg='http_headers cookies for wrong domain')
self.assertFalse(result.get('cookies'), msg='Cookies set in cookies field for wrong domain')
self.assertFalse(ydl.cookiejar.get_cookie_header(fmt['url']), msg='Cookies set in cookiejar for wrong domain')
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()

View File

@ -46,6 +46,20 @@ class TestYoutubeDLCookieJar(unittest.TestCase):
# will be ignored # will be ignored
self.assertFalse(cookiejar._cookies) self.assertFalse(cookiejar._cookies)
def test_get_cookie_header(self):
cookiejar = YoutubeDLCookieJar('./test/testdata/cookies/httponly_cookies.txt')
cookiejar.load(ignore_discard=True, ignore_expires=True)
header = cookiejar.get_cookie_header('https://www.foobar.foobar')
self.assertIn('HTTPONLY_COOKIE', header)
def test_get_cookies_for_url(self):
cookiejar = YoutubeDLCookieJar('./test/testdata/cookies/session_cookies.txt')
cookiejar.load(ignore_discard=True, ignore_expires=True)
cookies = cookiejar.get_cookies_for_url('https://www.foobar.foobar/')
self.assertEqual(len(cookies), 2)
cookies = cookiejar.get_cookies_for_url('https://foobar.foobar/')
self.assertFalse(cookies)
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()

View File

@ -12,20 +12,65 @@ sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from test.helper import ( from test.helper import (
FakeLogger, FakeLogger,
FakeYDL,
http_server_port, http_server_port,
try_rm, try_rm,
) )
from youtube_dl import YoutubeDL from youtube_dl import YoutubeDL
from youtube_dl.compat import compat_http_server from youtube_dl.compat import (
from youtube_dl.utils import encodeFilename compat_http_cookiejar_Cookie,
from youtube_dl.downloader.external import Aria2pFD compat_http_server,
compat_kwargs,
)
from youtube_dl.utils import (
encodeFilename,
join_nonempty,
)
from youtube_dl.downloader.external import (
Aria2cFD,
Aria2pFD,
AxelFD,
CurlFD,
FFmpegFD,
HttpieFD,
WgetFD,
)
import threading import threading
TEST_DIR = os.path.dirname(os.path.abspath(__file__))
TEST_SIZE = 10 * 1024 TEST_SIZE = 10 * 1024
TEST_COOKIE = {
'version': 0,
'name': 'test',
'value': 'ytdlp',
'port': None,
'port_specified': False,
'domain': '.example.com',
'domain_specified': True,
'domain_initial_dot': False,
'path': '/',
'path_specified': True,
'secure': False,
'expires': None,
'discard': False,
'comment': None,
'comment_url': None,
'rest': {},
}
TEST_COOKIE_VALUE = join_nonempty('name', 'value', delim='=', from_dict=TEST_COOKIE)
TEST_INFO = {'url': 'http://www.example.com/'}
def cookiejar_Cookie(**cookie_args):
return compat_http_cookiejar_Cookie(**compat_kwargs(cookie_args))
def ifExternalFDAvailable(externalFD):
return unittest.skipUnless(externalFD.available(),
externalFD.get_basename() + ' not found')
class HTTPTestRequestHandler(compat_http_server.BaseHTTPRequestHandler): class HTTPTestRequestHandler(compat_http_server.BaseHTTPRequestHandler):
def log_message(self, format, *args): def log_message(self, format, *args):
@ -70,7 +115,7 @@ class HTTPTestRequestHandler(compat_http_server.BaseHTTPRequestHandler):
assert False, 'unrecognised server path' assert False, 'unrecognised server path'
@unittest.skipUnless(Aria2pFD.available(), 'aria2p module not found') @ifExternalFDAvailable(Aria2pFD)
class TestAria2pFD(unittest.TestCase): class TestAria2pFD(unittest.TestCase):
def setUp(self): def setUp(self):
self.httpd = compat_http_server.HTTPServer( self.httpd = compat_http_server.HTTPServer(
@ -111,5 +156,103 @@ class TestAria2pFD(unittest.TestCase):
}) })
@ifExternalFDAvailable(HttpieFD)
class TestHttpieFD(unittest.TestCase):
def test_make_cmd(self):
with FakeYDL() as ydl:
downloader = HttpieFD(ydl, {})
self.assertEqual(
downloader._make_cmd('test', TEST_INFO),
['http', '--download', '--output', 'test', 'http://www.example.com/'])
# Test cookie header is added
ydl.cookiejar.set_cookie(cookiejar_Cookie(**TEST_COOKIE))
self.assertEqual(
downloader._make_cmd('test', TEST_INFO),
['http', '--download', '--output', 'test',
'http://www.example.com/', 'Cookie:' + TEST_COOKIE_VALUE])
@ifExternalFDAvailable(AxelFD)
class TestAxelFD(unittest.TestCase):
def test_make_cmd(self):
with FakeYDL() as ydl:
downloader = AxelFD(ydl, {})
self.assertEqual(
downloader._make_cmd('test', TEST_INFO),
['axel', '-o', 'test', '--', 'http://www.example.com/'])
# Test cookie header is added
ydl.cookiejar.set_cookie(cookiejar_Cookie(**TEST_COOKIE))
self.assertEqual(
downloader._make_cmd('test', TEST_INFO),
['axel', '-o', 'test', '-H', 'Cookie: ' + TEST_COOKIE_VALUE,
'--max-redirect=0', '--', 'http://www.example.com/'])
@ifExternalFDAvailable(WgetFD)
class TestWgetFD(unittest.TestCase):
def test_make_cmd(self):
with FakeYDL() as ydl:
downloader = WgetFD(ydl, {})
self.assertNotIn('--load-cookies', downloader._make_cmd('test', TEST_INFO))
# Test cookiejar tempfile arg is added
ydl.cookiejar.set_cookie(cookiejar_Cookie(**TEST_COOKIE))
self.assertIn('--load-cookies', downloader._make_cmd('test', TEST_INFO))
@ifExternalFDAvailable(CurlFD)
class TestCurlFD(unittest.TestCase):
def test_make_cmd(self):
with FakeYDL() as ydl:
downloader = CurlFD(ydl, {})
self.assertNotIn('--cookie', downloader._make_cmd('test', TEST_INFO))
# Test cookie header is added
ydl.cookiejar.set_cookie(cookiejar_Cookie(**TEST_COOKIE))
self.assertIn('--cookie', downloader._make_cmd('test', TEST_INFO))
self.assertIn(TEST_COOKIE_VALUE, downloader._make_cmd('test', TEST_INFO))
@ifExternalFDAvailable(Aria2cFD)
class TestAria2cFD(unittest.TestCase):
def test_make_cmd(self):
with FakeYDL() as ydl:
downloader = Aria2cFD(ydl, {})
downloader._make_cmd('test', TEST_INFO)
self.assertFalse(hasattr(downloader, '_cookies_tempfile'))
# Test cookiejar tempfile arg is added
ydl.cookiejar.set_cookie(cookiejar_Cookie(**TEST_COOKIE))
cmd = downloader._make_cmd('test', TEST_INFO)
self.assertIn('--load-cookies=%s' % downloader._cookies_tempfile, cmd)
@ifExternalFDAvailable(FFmpegFD)
class TestFFmpegFD(unittest.TestCase):
_args = []
def _test_cmd(self, args):
self._args = args
def test_make_cmd(self):
with FakeYDL() as ydl:
downloader = FFmpegFD(ydl, {})
downloader._debug_cmd = self._test_cmd
info_dict = TEST_INFO.copy()
info_dict['ext'] = 'mp4'
downloader._call_downloader('test', info_dict)
self.assertEqual(self._args, [
'ffmpeg', '-y', '-i', 'http://www.example.com/',
'-c', 'copy', '-f', 'mp4', 'file:test'])
# Test cookies arg is added
ydl.cookiejar.set_cookie(cookiejar_Cookie(**TEST_COOKIE))
downloader._call_downloader('test', info_dict)
self.assertEqual(self._args, [
'ffmpeg', '-y', '-cookies', TEST_COOKIE_VALUE + '; path=/; domain=.example.com;\r\n',
'-i', 'http://www.example.com/', '-c', 'copy', '-f', 'mp4', 'file:test'])
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()

View File

@ -24,21 +24,24 @@ except AttributeError:
class TestExecution(unittest.TestCase): class TestExecution(unittest.TestCase):
def setUp(self):
self.module = 'youtube_dl'
if sys.version_info < (2, 7):
self.module += '.__main__'
def test_import(self): def test_import(self):
subprocess.check_call([sys.executable, '-c', 'import youtube_dl'], cwd=rootDir) subprocess.check_call([sys.executable, '-c', 'import youtube_dl'], cwd=rootDir)
@unittest.skipIf(sys.version_info < (2, 7), 'Python 2.6 doesn\'t support package execution')
def test_module_exec(self): def test_module_exec(self):
subprocess.check_call([sys.executable, '-m', 'youtube_dl', '--version'], cwd=rootDir, stdout=_DEV_NULL) subprocess.check_call([sys.executable, '-m', self.module, '--version'], cwd=rootDir, stdout=_DEV_NULL)
def test_main_exec(self): def test_main_exec(self):
subprocess.check_call([sys.executable, os.path.normpath('youtube_dl/__main__.py'), '--version'], cwd=rootDir, stdout=_DEV_NULL) subprocess.check_call([sys.executable, os.path.normpath('youtube_dl/__main__.py'), '--version'], cwd=rootDir, stdout=_DEV_NULL)
@unittest.skipIf(sys.version_info < (2, 7), 'Python 2.6 doesn\'t support package execution')
def test_cmdline_umlauts(self): def test_cmdline_umlauts(self):
os.environ['PYTHONIOENCODING'] = 'utf-8' os.environ['PYTHONIOENCODING'] = 'utf-8'
p = subprocess.Popen( p = subprocess.Popen(
[sys.executable, os.path.normpath('youtube_dl/__main__.py'), encodeArgument('ä'), '--version'], [sys.executable, '-m', self.module, encodeArgument('ä'), '--version'],
cwd=rootDir, stdout=_DEV_NULL, stderr=subprocess.PIPE) cwd=rootDir, stdout=_DEV_NULL, stderr=subprocess.PIPE)
_, stderr = p.communicate() _, stderr = p.communicate()
self.assertFalse(stderr) self.assertFalse(stderr)

View File

@ -8,33 +8,161 @@ import sys
import unittest import unittest
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
import contextlib
import gzip
import io
import ssl
import tempfile
import threading
import zlib
# avoid deprecated alias assertRaisesRegexp
if hasattr(unittest.TestCase, 'assertRaisesRegex'):
unittest.TestCase.assertRaisesRegexp = unittest.TestCase.assertRaisesRegex
try:
import brotli
except ImportError:
brotli = None
try:
from urllib.request import pathname2url
except ImportError:
from urllib import pathname2url
from youtube_dl.compat import (
compat_http_cookiejar_Cookie,
compat_http_server,
compat_str as str,
compat_urllib_error,
compat_urllib_HTTPError,
compat_urllib_parse,
compat_urllib_request,
)
from youtube_dl.utils import (
sanitized_Request,
urlencode_postdata,
)
from test.helper import ( from test.helper import (
FakeYDL,
FakeLogger, FakeLogger,
http_server_port, http_server_port,
) )
from youtube_dl import YoutubeDL from youtube_dl import YoutubeDL
from youtube_dl.compat import compat_http_server, compat_urllib_request
import ssl
import threading
TEST_DIR = os.path.dirname(os.path.abspath(__file__)) TEST_DIR = os.path.dirname(os.path.abspath(__file__))
class HTTPTestRequestHandler(compat_http_server.BaseHTTPRequestHandler): class HTTPTestRequestHandler(compat_http_server.BaseHTTPRequestHandler):
protocol_version = 'HTTP/1.1'
# work-around old/new -style class inheritance
def super(self, meth_name, *args, **kwargs):
from types import MethodType
try:
super()
fn = lambda s, m, *a, **k: getattr(super(), m)(*a, **k)
except TypeError:
fn = lambda s, m, *a, **k: getattr(compat_http_server.BaseHTTPRequestHandler, m)(s, *a, **k)
self.super = MethodType(fn, self)
return self.super(meth_name, *args, **kwargs)
def log_message(self, format, *args): def log_message(self, format, *args):
pass pass
def do_GET(self): def _headers(self):
if self.path == '/video.html': payload = str(self.headers).encode('utf-8')
self.send_response(200) self.send_response(200)
self.send_header('Content-Type', 'application/json')
self.send_header('Content-Length', str(len(payload)))
self.end_headers()
self.wfile.write(payload)
def _redirect(self):
self.send_response(int(self.path[len('/redirect_'):]))
self.send_header('Location', '/method')
self.send_header('Content-Length', '0')
self.end_headers()
def _method(self, method, payload=None):
self.send_response(200)
self.send_header('Content-Length', str(len(payload or '')))
self.send_header('Method', method)
self.end_headers()
if payload:
self.wfile.write(payload)
def _status(self, status):
payload = '<html>{0} NOT FOUND</html>'.format(status).encode('utf-8')
self.send_response(int(status))
self.send_header('Content-Type', 'text/html; charset=utf-8') self.send_header('Content-Type', 'text/html; charset=utf-8')
self.send_header('Content-Length', str(len(payload)))
self.end_headers() self.end_headers()
self.wfile.write(b'<html><video src="/vid.mp4" /></html>') self.wfile.write(payload)
def _read_data(self):
if 'Content-Length' in self.headers:
return self.rfile.read(int(self.headers['Content-Length']))
def _test_url(self, path, host='127.0.0.1', scheme='http', port=None):
return '{0}://{1}:{2}/{3}'.format(
scheme, host,
port if port is not None
else http_server_port(self.server), path)
def do_POST(self):
data = self._read_data()
if self.path.startswith('/redirect_'):
self._redirect()
elif self.path.startswith('/method'):
self._method('POST', data)
elif self.path.startswith('/headers'):
self._headers()
else:
self._status(404)
def do_HEAD(self):
if self.path.startswith('/redirect_'):
self._redirect()
elif self.path.startswith('/method'):
self._method('HEAD')
else:
self._status(404)
def do_PUT(self):
data = self._read_data()
if self.path.startswith('/redirect_'):
self._redirect()
elif self.path.startswith('/method'):
self._method('PUT', data)
else:
self._status(404)
def do_GET(self):
def respond(payload=b'<html><video src="/vid.mp4" /></html>',
payload_type='text/html; charset=utf-8',
payload_encoding=None,
resp_code=200):
self.send_response(resp_code)
self.send_header('Content-Type', payload_type)
if payload_encoding:
self.send_header('Content-Encoding', payload_encoding)
self.send_header('Content-Length', str(len(payload))) # required for persistent connections
self.end_headers()
self.wfile.write(payload)
def gzip_compress(p):
buf = io.BytesIO()
with contextlib.closing(gzip.GzipFile(fileobj=buf, mode='wb')) as f:
f.write(p)
return buf.getvalue()
if self.path == '/video.html':
respond()
elif self.path == '/vid.mp4': elif self.path == '/vid.mp4':
self.send_response(200) respond(b'\x00\x00\x00\x00\x20\x66\x74[video]', 'video/mp4')
self.send_header('Content-Type', 'video/mp4')
self.end_headers()
self.wfile.write(b'\x00\x00\x00\x00\x20\x66\x74[video]')
elif self.path == '/302': elif self.path == '/302':
if sys.version_info[0] == 3: if sys.version_info[0] == 3:
# XXX: Python 3 http server does not allow non-ASCII header values # XXX: Python 3 http server does not allow non-ASCII header values
@ -42,60 +170,316 @@ class HTTPTestRequestHandler(compat_http_server.BaseHTTPRequestHandler):
self.end_headers() self.end_headers()
return return
new_url = 'http://127.0.0.1:%d/中文.html' % http_server_port(self.server) new_url = self._test_url('中文.html')
self.send_response(302) self.send_response(302)
self.send_header(b'Location', new_url.encode('utf-8')) self.send_header(b'Location', new_url.encode('utf-8'))
self.end_headers() self.end_headers()
elif self.path == '/%E4%B8%AD%E6%96%87.html': elif self.path == '/%E4%B8%AD%E6%96%87.html':
self.send_response(200) respond()
self.send_header('Content-Type', 'text/html; charset=utf-8') elif self.path == '/%c7%9f':
respond()
elif self.path.startswith('/redirect_'):
self._redirect()
elif self.path.startswith('/method'):
self._method('GET')
elif self.path.startswith('/headers'):
self._headers()
elif self.path.startswith('/308-to-headers'):
self.send_response(308)
self.send_header('Location', '/headers')
self.send_header('Content-Length', '0')
self.end_headers() self.end_headers()
self.wfile.write(b'<html><video src="/vid.mp4" /></html>') elif self.path == '/trailing_garbage':
payload = b'<html><video src="/vid.mp4" /></html>'
compressed = gzip_compress(payload) + b'trailing garbage'
respond(compressed, payload_encoding='gzip')
elif self.path == '/302-non-ascii-redirect':
new_url = self._test_url('中文.html')
# actually respond with permanent redirect
self.send_response(301)
self.send_header('Location', new_url)
self.send_header('Content-Length', '0')
self.end_headers()
elif self.path == '/content-encoding':
encodings = self.headers.get('ytdl-encoding', '')
payload = b'<html><video src="/vid.mp4" /></html>'
for encoding in filter(None, (e.strip() for e in encodings.split(','))):
if encoding == 'br' and brotli:
payload = brotli.compress(payload)
elif encoding == 'gzip':
payload = gzip_compress(payload)
elif encoding == 'deflate':
payload = zlib.compress(payload)
elif encoding == 'unsupported':
payload = b'raw'
break
else: else:
assert False self._status(415)
return
respond(payload, payload_encoding=encodings)
else:
self._status(404)
def send_header(self, keyword, value):
"""
Forcibly allow HTTP server to send non percent-encoded non-ASCII characters in headers.
This is against what is defined in RFC 3986: but we need to test that we support this
since some sites incorrectly do this.
"""
if keyword.lower() == 'connection':
return self.super('send_header', keyword, value)
if not hasattr(self, '_headers_buffer'):
self._headers_buffer = []
self._headers_buffer.append('{0}: {1}\r\n'.format(keyword, value).encode('utf-8'))
def end_headers(self):
if hasattr(self, '_headers_buffer'):
self.wfile.write(b''.join(self._headers_buffer))
self._headers_buffer = []
self.super('end_headers')
class TestHTTP(unittest.TestCase): class TestHTTP(unittest.TestCase):
def setUp(self): def setUp(self):
self.httpd = compat_http_server.HTTPServer( # HTTP server
self.http_httpd = compat_http_server.HTTPServer(
('127.0.0.1', 0), HTTPTestRequestHandler) ('127.0.0.1', 0), HTTPTestRequestHandler)
self.port = http_server_port(self.httpd) self.http_port = http_server_port(self.http_httpd)
self.server_thread = threading.Thread(target=self.httpd.serve_forever)
self.server_thread.daemon = True self.http_server_thread = threading.Thread(target=self.http_httpd.serve_forever)
self.server_thread.start() self.http_server_thread.daemon = True
self.http_server_thread.start()
try:
from http.server import ThreadingHTTPServer
except ImportError:
try:
from socketserver import ThreadingMixIn
except ImportError:
from SocketServer import ThreadingMixIn
class ThreadingHTTPServer(ThreadingMixIn, compat_http_server.HTTPServer):
pass
# HTTPS server
certfn = os.path.join(TEST_DIR, 'testcert.pem')
self.https_httpd = ThreadingHTTPServer(
('127.0.0.1', 0), HTTPTestRequestHandler)
try:
sslctx = ssl.SSLContext(ssl.PROTOCOL_SSLv23)
sslctx.verify_mode = ssl.CERT_NONE
sslctx.check_hostname = False
sslctx.load_cert_chain(certfn, None)
self.https_httpd.socket = sslctx.wrap_socket(
self.https_httpd.socket, server_side=True)
except AttributeError:
self.https_httpd.socket = ssl.wrap_socket(
self.https_httpd.socket, certfile=certfn, server_side=True)
self.https_port = http_server_port(self.https_httpd)
self.https_server_thread = threading.Thread(target=self.https_httpd.serve_forever)
self.https_server_thread.daemon = True
self.https_server_thread.start()
def tearDown(self):
def closer(svr):
def _closer():
svr.shutdown()
svr.server_close()
return _closer
shutdown_thread = threading.Thread(target=closer(self.http_httpd))
shutdown_thread.start()
self.http_server_thread.join(2.0)
shutdown_thread = threading.Thread(target=closer(self.https_httpd))
shutdown_thread.start()
self.https_server_thread.join(2.0)
def _test_url(self, path, host='127.0.0.1', scheme='http', port=None):
return '{0}://{1}:{2}/{3}'.format(
scheme, host,
port if port is not None
else self.https_port if scheme == 'https'
else self.http_port, path)
@unittest.skipUnless(
sys.version_info >= (3, 2)
or (sys.version_info[0] == 2 and sys.version_info[1:] >= (7, 9)),
'No support for certificate check in SSL')
def test_nocheckcertificate(self):
with FakeYDL({'logger': FakeLogger()}) as ydl:
with self.assertRaises(compat_urllib_error.URLError):
ydl.urlopen(sanitized_Request(self._test_url('headers', scheme='https')))
with FakeYDL({'logger': FakeLogger(), 'nocheckcertificate': True}) as ydl:
r = ydl.urlopen(sanitized_Request(self._test_url('headers', scheme='https')))
self.assertEqual(r.getcode(), 200)
r.close()
def test_percent_encode(self):
with FakeYDL() as ydl:
# Unicode characters should be encoded with uppercase percent-encoding
res = ydl.urlopen(sanitized_Request(self._test_url('中文.html')))
self.assertEqual(res.getcode(), 200)
res.close()
# don't normalize existing percent encodings
res = ydl.urlopen(sanitized_Request(self._test_url('%c7%9f')))
self.assertEqual(res.getcode(), 200)
res.close()
def test_unicode_path_redirection(self): def test_unicode_path_redirection(self):
# XXX: Python 3 http server does not allow non-ASCII header values with FakeYDL() as ydl:
if sys.version_info[0] == 3: r = ydl.urlopen(sanitized_Request(self._test_url('302-non-ascii-redirect')))
return self.assertEqual(r.url, self._test_url('%E4%B8%AD%E6%96%87.html'))
r.close()
ydl = YoutubeDL({'logger': FakeLogger()}) def test_redirect(self):
r = ydl.extract_info('http://127.0.0.1:%d/302' % self.port) with FakeYDL() as ydl:
self.assertEqual(r['entries'][0]['url'], 'http://127.0.0.1:%d/vid.mp4' % self.port) def do_req(redirect_status, method, check_no_content=False):
data = b'testdata' if method in ('POST', 'PUT') else None
res = ydl.urlopen(sanitized_Request(
self._test_url('redirect_{0}'.format(redirect_status)),
method=method, data=data))
if check_no_content:
self.assertNotIn('Content-Type', res.headers)
return res.read().decode('utf-8'), res.headers.get('method', '')
# A 303 must either use GET or HEAD for subsequent request
self.assertEqual(do_req(303, 'POST'), ('', 'GET'))
self.assertEqual(do_req(303, 'HEAD'), ('', 'HEAD'))
self.assertEqual(do_req(303, 'PUT'), ('', 'GET'))
class TestHTTPS(unittest.TestCase): # 301 and 302 turn POST only into a GET, with no Content-Type
def setUp(self): self.assertEqual(do_req(301, 'POST', True), ('', 'GET'))
certfn = os.path.join(TEST_DIR, 'testcert.pem') self.assertEqual(do_req(301, 'HEAD'), ('', 'HEAD'))
self.httpd = compat_http_server.HTTPServer( self.assertEqual(do_req(302, 'POST', True), ('', 'GET'))
('127.0.0.1', 0), HTTPTestRequestHandler) self.assertEqual(do_req(302, 'HEAD'), ('', 'HEAD'))
self.httpd.socket = ssl.wrap_socket(
self.httpd.socket, certfile=certfn, server_side=True)
self.port = http_server_port(self.httpd)
self.server_thread = threading.Thread(target=self.httpd.serve_forever)
self.server_thread.daemon = True
self.server_thread.start()
def test_nocheckcertificate(self): self.assertEqual(do_req(301, 'PUT'), ('testdata', 'PUT'))
if sys.version_info >= (2, 7, 9): # No certificate checking anyways self.assertEqual(do_req(302, 'PUT'), ('testdata', 'PUT'))
ydl = YoutubeDL({'logger': FakeLogger()})
self.assertRaises(
Exception,
ydl.extract_info, 'https://127.0.0.1:%d/video.html' % self.port)
ydl = YoutubeDL({'logger': FakeLogger(), 'nocheckcertificate': True}) # 307 and 308 should not change method
r = ydl.extract_info('https://127.0.0.1:%d/video.html' % self.port) for m in ('POST', 'PUT'):
self.assertEqual(r['entries'][0]['url'], 'https://127.0.0.1:%d/vid.mp4' % self.port) self.assertEqual(do_req(307, m), ('testdata', m))
self.assertEqual(do_req(308, m), ('testdata', m))
self.assertEqual(do_req(307, 'HEAD'), ('', 'HEAD'))
self.assertEqual(do_req(308, 'HEAD'), ('', 'HEAD'))
# These should not redirect and instead raise an HTTPError
for code in (300, 304, 305, 306):
with self.assertRaises(compat_urllib_HTTPError):
do_req(code, 'GET')
def test_content_type(self):
# https://github.com/yt-dlp/yt-dlp/commit/379a4f161d4ad3e40932dcf5aca6e6fb9715ab28
with FakeYDL({'nocheckcertificate': True}) as ydl:
# method should be auto-detected as POST
r = sanitized_Request(self._test_url('headers', scheme='https'), data=urlencode_postdata({'test': 'test'}))
headers = ydl.urlopen(r).read().decode('utf-8')
self.assertIn('Content-Type: application/x-www-form-urlencoded', headers)
# test http
r = sanitized_Request(self._test_url('headers'), data=urlencode_postdata({'test': 'test'}))
headers = ydl.urlopen(r).read().decode('utf-8')
self.assertIn('Content-Type: application/x-www-form-urlencoded', headers)
def test_cookiejar(self):
with FakeYDL() as ydl:
ydl.cookiejar.set_cookie(compat_http_cookiejar_Cookie(
0, 'test', 'ytdl', None, False, '127.0.0.1', True,
False, '/headers', True, False, None, False, None, None, {}))
data = ydl.urlopen(sanitized_Request(
self._test_url('headers'))).read().decode('utf-8')
self.assertIn('Cookie: test=ytdl', data)
def test_passed_cookie_header(self):
# We should accept a Cookie header being passed as in normal headers and handle it appropriately.
with FakeYDL() as ydl:
# Specified Cookie header should be used
res = ydl.urlopen(sanitized_Request(
self._test_url('headers'), headers={'Cookie': 'test=test'})).read().decode('utf-8')
self.assertIn('Cookie: test=test', res)
# Specified Cookie header should be removed on any redirect
res = ydl.urlopen(sanitized_Request(
self._test_url('308-to-headers'), headers={'Cookie': 'test=test'})).read().decode('utf-8')
self.assertNotIn('Cookie: test=test', res)
# Specified Cookie header should override global cookiejar for that request
ydl.cookiejar.set_cookie(compat_http_cookiejar_Cookie(
0, 'test', 'ytdlp', None, False, '127.0.0.1', True,
False, '/headers', True, False, None, False, None, None, {}))
data = ydl.urlopen(sanitized_Request(
self._test_url('headers'), headers={'Cookie': 'test=test'})).read().decode('utf-8')
self.assertNotIn('Cookie: test=ytdlp', data)
self.assertIn('Cookie: test=test', data)
def test_no_compression_compat_header(self):
with FakeYDL() as ydl:
data = ydl.urlopen(
sanitized_Request(
self._test_url('headers'),
headers={'Youtubedl-no-compression': True})).read()
self.assertIn(b'Accept-Encoding: identity', data)
self.assertNotIn(b'youtubedl-no-compression', data.lower())
def test_gzip_trailing_garbage(self):
# https://github.com/ytdl-org/youtube-dl/commit/aa3e950764337ef9800c936f4de89b31c00dfcf5
# https://github.com/ytdl-org/youtube-dl/commit/6f2ec15cee79d35dba065677cad9da7491ec6e6f
with FakeYDL() as ydl:
data = ydl.urlopen(sanitized_Request(self._test_url('trailing_garbage'))).read().decode('utf-8')
self.assertEqual(data, '<html><video src="/vid.mp4" /></html>')
def __test_compression(self, encoding):
with FakeYDL() as ydl:
res = ydl.urlopen(
sanitized_Request(
self._test_url('content-encoding'),
headers={'ytdl-encoding': encoding}))
self.assertEqual(res.headers.get('Content-Encoding'), encoding)
self.assertEqual(res.read(), b'<html><video src="/vid.mp4" /></html>')
@unittest.skipUnless(brotli, 'brotli support is not installed')
@unittest.expectedFailure
def test_brotli(self):
self.__test_compression('br')
@unittest.expectedFailure
def test_deflate(self):
self.__test_compression('deflate')
@unittest.expectedFailure
def test_gzip(self):
self.__test_compression('gzip')
@unittest.expectedFailure # not yet implemented
def test_multiple_encodings(self):
# https://www.rfc-editor.org/rfc/rfc9110.html#section-8.4
with FakeYDL() as ydl:
for pair in ('gzip,deflate', 'deflate, gzip', 'gzip, gzip', 'deflate, deflate'):
res = ydl.urlopen(
sanitized_Request(
self._test_url('content-encoding'),
headers={'ytdl-encoding': pair}))
self.assertEqual(res.headers.get('Content-Encoding'), pair)
self.assertEqual(res.read(), b'<html><video src="/vid.mp4" /></html>')
def test_unsupported_encoding(self):
# it should return the raw content
with FakeYDL() as ydl:
res = ydl.urlopen(
sanitized_Request(
self._test_url('content-encoding'),
headers={'ytdl-encoding': 'unsupported'}))
self.assertEqual(res.headers.get('Content-Encoding'), 'unsupported')
self.assertEqual(res.read(), b'raw')
def _build_proxy_handler(name): def _build_proxy_handler(name):
@ -109,7 +493,7 @@ def _build_proxy_handler(name):
self.send_response(200) self.send_response(200)
self.send_header('Content-Type', 'text/plain; charset=utf-8') self.send_header('Content-Type', 'text/plain; charset=utf-8')
self.end_headers() self.end_headers()
self.wfile.write('{self.proxy_name}: {self.path}'.format(self=self).encode('utf-8')) self.wfile.write('{0}: {1}'.format(self.proxy_name, self.path).encode('utf-8'))
return HTTPTestRequestHandler return HTTPTestRequestHandler
@ -129,10 +513,30 @@ class TestProxy(unittest.TestCase):
self.geo_proxy_thread.daemon = True self.geo_proxy_thread.daemon = True
self.geo_proxy_thread.start() self.geo_proxy_thread.start()
def tearDown(self):
def closer(svr):
def _closer():
svr.shutdown()
svr.server_close()
return _closer
shutdown_thread = threading.Thread(target=closer(self.proxy))
shutdown_thread.start()
self.proxy_thread.join(2.0)
shutdown_thread = threading.Thread(target=closer(self.geo_proxy))
shutdown_thread.start()
self.geo_proxy_thread.join(2.0)
def _test_proxy(self, host='127.0.0.1', port=None):
return '{0}:{1}'.format(
host, port if port is not None else self.port)
def test_proxy(self): def test_proxy(self):
geo_proxy = '127.0.0.1:{0}'.format(self.geo_port) geo_proxy = self._test_proxy(port=self.geo_port)
ydl = YoutubeDL({ ydl = YoutubeDL({
'proxy': '127.0.0.1:{0}'.format(self.port), 'proxy': self._test_proxy(),
'geo_verification_proxy': geo_proxy, 'geo_verification_proxy': geo_proxy,
}) })
url = 'http://foo.com/bar' url = 'http://foo.com/bar'
@ -146,7 +550,7 @@ class TestProxy(unittest.TestCase):
def test_proxy_with_idn(self): def test_proxy_with_idn(self):
ydl = YoutubeDL({ ydl = YoutubeDL({
'proxy': '127.0.0.1:{0}'.format(self.port), 'proxy': self._test_proxy(),
}) })
url = 'http://中文.tw/' url = 'http://中文.tw/'
response = ydl.urlopen(url).read().decode('utf-8') response = ydl.urlopen(url).read().decode('utf-8')
@ -154,5 +558,25 @@ class TestProxy(unittest.TestCase):
self.assertEqual(response, 'normal: http://xn--fiq228c.tw/') self.assertEqual(response, 'normal: http://xn--fiq228c.tw/')
class TestFileURL(unittest.TestCase):
# See https://github.com/ytdl-org/youtube-dl/issues/8227
def test_file_urls(self):
tf = tempfile.NamedTemporaryFile(delete=False)
tf.write(b'foobar')
tf.close()
url = compat_urllib_parse.urljoin('file://', pathname2url(tf.name))
with FakeYDL() as ydl:
self.assertRaisesRegexp(
compat_urllib_error.URLError, 'file:// scheme is explicitly disabled in youtube-dl for security reasons', ydl.urlopen, url)
# not yet implemented
"""
with FakeYDL({'enable_file_urls': True}) as ydl:
res = ydl.urlopen(url)
self.assertEqual(res.read(), b'foobar')
res.close()
"""
os.unlink(tf.name)
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()

View File

@ -79,10 +79,12 @@ from youtube_dl.utils import (
rot47, rot47,
shell_quote, shell_quote,
smuggle_url, smuggle_url,
str_or_none,
str_to_int, str_to_int,
strip_jsonp, strip_jsonp,
strip_or_none, strip_or_none,
subtitles_filename, subtitles_filename,
T,
timeconvert, timeconvert,
traverse_obj, traverse_obj,
try_call, try_call,
@ -1566,6 +1568,7 @@ Line 1
self.assertEqual(variadic('spam', allowed_types=[dict]), 'spam') self.assertEqual(variadic('spam', allowed_types=[dict]), 'spam')
def test_traverse_obj(self): def test_traverse_obj(self):
str = compat_str
_TEST_DATA = { _TEST_DATA = {
100: 100, 100: 100,
1.2: 1.2, 1.2: 1.2,
@ -1598,8 +1601,8 @@ Line 1
# Test Ellipsis behavior # Test Ellipsis behavior
self.assertCountEqual(traverse_obj(_TEST_DATA, Ellipsis), self.assertCountEqual(traverse_obj(_TEST_DATA, Ellipsis),
(item for item in _TEST_DATA.values() if item is not None), (item for item in _TEST_DATA.values() if item not in (None, {})),
msg='`...` should give all values except `None`') msg='`...` should give all non discarded values')
self.assertCountEqual(traverse_obj(_TEST_DATA, ('urls', 0, Ellipsis)), _TEST_DATA['urls'][0].values(), self.assertCountEqual(traverse_obj(_TEST_DATA, ('urls', 0, Ellipsis)), _TEST_DATA['urls'][0].values(),
msg='`...` selection for dicts should select all values') msg='`...` selection for dicts should select all values')
self.assertEqual(traverse_obj(_TEST_DATA, (Ellipsis, Ellipsis, 'url')), self.assertEqual(traverse_obj(_TEST_DATA, (Ellipsis, Ellipsis, 'url')),
@ -1607,13 +1610,51 @@ Line 1
msg='nested `...` queries should work') msg='nested `...` queries should work')
self.assertCountEqual(traverse_obj(_TEST_DATA, (Ellipsis, Ellipsis, 'index')), range(4), self.assertCountEqual(traverse_obj(_TEST_DATA, (Ellipsis, Ellipsis, 'index')), range(4),
msg='`...` query result should be flattened') msg='`...` query result should be flattened')
self.assertEqual(traverse_obj(iter(range(4)), Ellipsis), list(range(4)),
msg='`...` should accept iterables')
# Test function as key # Test function as key
self.assertEqual(traverse_obj(_TEST_DATA, lambda x, y: x == 'urls' and isinstance(y, list)), self.assertEqual(traverse_obj(_TEST_DATA, lambda x, y: x == 'urls' and isinstance(y, list)),
[_TEST_DATA['urls']], [_TEST_DATA['urls']],
msg='function as query key should perform a filter based on (key, value)') msg='function as query key should perform a filter based on (key, value)')
self.assertCountEqual(traverse_obj(_TEST_DATA, lambda _, x: isinstance(x[0], compat_str)), ('str',), self.assertCountEqual(traverse_obj(_TEST_DATA, lambda _, x: isinstance(x[0], str)), set(('str',)),
msg='exceptions in the query function should be caught') msg='exceptions in the query function should be catched')
self.assertEqual(traverse_obj(iter(range(4)), lambda _, x: x % 2 == 0), [0, 2],
msg='function key should accept iterables')
if __debug__:
with self.assertRaises(Exception, msg='Wrong function signature should raise in debug'):
traverse_obj(_TEST_DATA, lambda a: Ellipsis)
with self.assertRaises(Exception, msg='Wrong function signature should raise in debug'):
traverse_obj(_TEST_DATA, lambda a, b, c: Ellipsis)
# Test set as key (transformation/type, like `expected_type`)
self.assertEqual(traverse_obj(_TEST_DATA, (Ellipsis, T(str.upper), )), ['STR'],
msg='Function in set should be a transformation')
self.assertEqual(traverse_obj(_TEST_DATA, (Ellipsis, T(str))), ['str'],
msg='Type in set should be a type filter')
self.assertEqual(traverse_obj(_TEST_DATA, T(dict)), _TEST_DATA,
msg='A single set should be wrapped into a path')
self.assertEqual(traverse_obj(_TEST_DATA, (Ellipsis, T(str.upper))), ['STR'],
msg='Transformation function should not raise')
self.assertEqual(traverse_obj(_TEST_DATA, (Ellipsis, T(str_or_none))),
[item for item in map(str_or_none, _TEST_DATA.values()) if item is not None],
msg='Function in set should be a transformation')
if __debug__:
with self.assertRaises(Exception, msg='Sets with length != 1 should raise in debug'):
traverse_obj(_TEST_DATA, set())
with self.assertRaises(Exception, msg='Sets with length != 1 should raise in debug'):
traverse_obj(_TEST_DATA, set((str.upper, str)))
# Test `slice` as a key
_SLICE_DATA = [0, 1, 2, 3, 4]
self.assertEqual(traverse_obj(_TEST_DATA, ('dict', slice(1))), None,
msg='slice on a dictionary should not throw')
self.assertEqual(traverse_obj(_SLICE_DATA, slice(1)), _SLICE_DATA[:1],
msg='slice key should apply slice to sequence')
self.assertEqual(traverse_obj(_SLICE_DATA, slice(1, 2)), _SLICE_DATA[1:2],
msg='slice key should apply slice to sequence')
self.assertEqual(traverse_obj(_SLICE_DATA, slice(1, 4, 2)), _SLICE_DATA[1:4:2],
msg='slice key should apply slice to sequence')
# Test alternative paths # Test alternative paths
self.assertEqual(traverse_obj(_TEST_DATA, 'fail', 'str'), 'str', self.assertEqual(traverse_obj(_TEST_DATA, 'fail', 'str'), 'str',
@ -1659,15 +1700,23 @@ Line 1
{0: ['https://www.example.com/1', 'https://www.example.com/0']}, {0: ['https://www.example.com/1', 'https://www.example.com/0']},
msg='triple nesting in dict path should be treated as branches') msg='triple nesting in dict path should be treated as branches')
self.assertEqual(traverse_obj(_TEST_DATA, {0: 'fail'}), {}, self.assertEqual(traverse_obj(_TEST_DATA, {0: 'fail'}), {},
msg='remove `None` values when dict key') msg='remove `None` values when top level dict key fails')
self.assertEqual(traverse_obj(_TEST_DATA, {0: 'fail'}, default=Ellipsis), {0: Ellipsis}, self.assertEqual(traverse_obj(_TEST_DATA, {0: 'fail'}, default=Ellipsis), {0: Ellipsis},
msg='do not remove `None` values if `default`') msg='use `default` if key fails and `default`')
self.assertEqual(traverse_obj(_TEST_DATA, {0: 'dict'}), {0: {}}, self.assertEqual(traverse_obj(_TEST_DATA, {0: 'dict'}), {},
msg='do not remove empty values when dict key') msg='remove empty values when dict key')
self.assertEqual(traverse_obj(_TEST_DATA, {0: 'dict'}, default=Ellipsis), {0: {}}, self.assertEqual(traverse_obj(_TEST_DATA, {0: 'dict'}, default=Ellipsis), {0: Ellipsis},
msg='do not remove empty values when dict key and a default') msg='use `default` when dict key and `default`')
self.assertEqual(traverse_obj(_TEST_DATA, {0: ('dict', Ellipsis)}), {0: []}, self.assertEqual(traverse_obj(_TEST_DATA, {0: {0: 'fail'}}), {},
msg='if branch in dict key not successful, return `[]`') msg='remove empty values when nested dict key fails')
self.assertEqual(traverse_obj(None, {0: 'fail'}), {},
msg='default to dict if pruned')
self.assertEqual(traverse_obj(None, {0: 'fail'}, default=Ellipsis), {0: Ellipsis},
msg='default to dict if pruned and default is given')
self.assertEqual(traverse_obj(_TEST_DATA, {0: {0: 'fail'}}, default=Ellipsis), {0: {0: Ellipsis}},
msg='use nested `default` when nested dict key fails and `default`')
self.assertEqual(traverse_obj(_TEST_DATA, {0: ('dict', Ellipsis)}), {},
msg='remove key if branch in dict key not successful')
# Testing default parameter behavior # Testing default parameter behavior
_DEFAULT_DATA = {'None': None, 'int': 0, 'list': []} _DEFAULT_DATA = {'None': None, 'int': 0, 'list': []}
@ -1691,20 +1740,55 @@ Line 1
msg='if branched but not successful return `[]`, not `default`') msg='if branched but not successful return `[]`, not `default`')
self.assertEqual(traverse_obj(_DEFAULT_DATA, ('list', Ellipsis)), [], self.assertEqual(traverse_obj(_DEFAULT_DATA, ('list', Ellipsis)), [],
msg='if branched but object is empty return `[]`, not `default`') msg='if branched but object is empty return `[]`, not `default`')
self.assertEqual(traverse_obj(None, Ellipsis), [],
msg='if branched but object is `None` return `[]`, not `default`')
self.assertEqual(traverse_obj({0: None}, (0, Ellipsis)), [],
msg='if branched but state is `None` return `[]`, not `default`')
branching_paths = [
('fail', Ellipsis),
(Ellipsis, 'fail'),
100 * ('fail',) + (Ellipsis,),
(Ellipsis,) + 100 * ('fail',),
]
for branching_path in branching_paths:
self.assertEqual(traverse_obj({}, branching_path), [],
msg='if branched but state is `None`, return `[]` (not `default`)')
self.assertEqual(traverse_obj({}, 'fail', branching_path), [],
msg='if branching in last alternative and previous did not match, return `[]` (not `default`)')
self.assertEqual(traverse_obj({0: 'x'}, 0, branching_path), 'x',
msg='if branching in last alternative and previous did match, return single value')
self.assertEqual(traverse_obj({0: 'x'}, branching_path, 0), 'x',
msg='if branching in first alternative and non-branching path does match, return single value')
self.assertEqual(traverse_obj({}, branching_path, 'fail'), None,
msg='if branching in first alternative and non-branching path does not match, return `default`')
# Testing expected_type behavior # Testing expected_type behavior
_EXPECTED_TYPE_DATA = {'str': 'str', 'int': 0} _EXPECTED_TYPE_DATA = {'str': 'str', 'int': 0}
self.assertEqual(traverse_obj(_EXPECTED_TYPE_DATA, 'str', expected_type=compat_str), 'str', self.assertEqual(traverse_obj(_EXPECTED_TYPE_DATA, 'str', expected_type=str),
msg='accept matching `expected_type` type') 'str', msg='accept matching `expected_type` type')
self.assertEqual(traverse_obj(_EXPECTED_TYPE_DATA, 'str', expected_type=int), None, self.assertEqual(traverse_obj(_EXPECTED_TYPE_DATA, 'str', expected_type=int),
msg='reject non matching `expected_type` type') None, msg='reject non matching `expected_type` type')
self.assertEqual(traverse_obj(_EXPECTED_TYPE_DATA, 'int', expected_type=lambda x: compat_str(x)), '0', self.assertEqual(traverse_obj(_EXPECTED_TYPE_DATA, 'int', expected_type=lambda x: str(x)),
msg='transform type using type function') '0', msg='transform type using type function')
self.assertEqual(traverse_obj(_EXPECTED_TYPE_DATA, 'str', self.assertEqual(traverse_obj(_EXPECTED_TYPE_DATA, 'str', expected_type=lambda _: 1 / 0),
expected_type=lambda _: 1 / 0), None, None, msg='wrap expected_type function in try_call')
msg='wrap expected_type function in try_call') self.assertEqual(traverse_obj(_EXPECTED_TYPE_DATA, Ellipsis, expected_type=str),
self.assertEqual(traverse_obj(_EXPECTED_TYPE_DATA, Ellipsis, expected_type=compat_str), ['str'], ['str'], msg='eliminate items that expected_type fails on')
msg='eliminate items that expected_type fails on') self.assertEqual(traverse_obj(_TEST_DATA, {0: 100, 1: 1.2}, expected_type=int),
{0: 100}, msg='type as expected_type should filter dict values')
self.assertEqual(traverse_obj(_TEST_DATA, {0: 100, 1: 1.2, 2: 'None'}, expected_type=str_or_none),
{0: '100', 1: '1.2'}, msg='function as expected_type should transform dict values')
self.assertEqual(traverse_obj(_TEST_DATA, ({0: 1.2}, 0, set((int_or_none,))), expected_type=int),
1, msg='expected_type should not filter non final dict values')
self.assertEqual(traverse_obj(_TEST_DATA, {0: {0: 100, 1: 'str'}}, expected_type=int),
{0: {0: 100}}, msg='expected_type should transform deep dict values')
self.assertEqual(traverse_obj(_TEST_DATA, [({0: '...'}, {0: '...'})], expected_type=type(Ellipsis)),
[{0: Ellipsis}, {0: Ellipsis}], msg='expected_type should transform branched dict values')
self.assertEqual(traverse_obj({1: {3: 4}}, [(1, 2), 3], expected_type=int),
[4], msg='expected_type regression for type matching in tuple branching')
self.assertEqual(traverse_obj(_TEST_DATA, ['data', Ellipsis], expected_type=int),
[], msg='expected_type regression for type matching in dict result')
# Test get_all behavior # Test get_all behavior
_GET_ALL_DATA = {'key': [0, 1, 2]} _GET_ALL_DATA = {'key': [0, 1, 2]}
@ -1749,14 +1833,23 @@ Line 1
_traverse_string=True), '.', _traverse_string=True), '.',
msg='traverse into converted data if `traverse_string`') msg='traverse into converted data if `traverse_string`')
self.assertEqual(traverse_obj(_TRAVERSE_STRING_DATA, ('str', Ellipsis), self.assertEqual(traverse_obj(_TRAVERSE_STRING_DATA, ('str', Ellipsis),
_traverse_string=True), list('str'), _traverse_string=True), 'str',
msg='`...` branching into string should result in list') msg='`...` should result in string (same value) if `traverse_string`')
self.assertEqual(traverse_obj(_TRAVERSE_STRING_DATA, ('str', slice(0, None, 2)),
_traverse_string=True), 'sr',
msg='`slice` should result in string if `traverse_string`')
self.assertEqual(traverse_obj(_TRAVERSE_STRING_DATA, ('str', lambda i, v: i or v == "s"),
_traverse_string=True), 'str',
msg='function should result in string if `traverse_string`')
self.assertEqual(traverse_obj(_TRAVERSE_STRING_DATA, ('str', (0, 2)), self.assertEqual(traverse_obj(_TRAVERSE_STRING_DATA, ('str', (0, 2)),
_traverse_string=True), ['s', 'r'], _traverse_string=True), ['s', 'r'],
msg='branching into string should result in list') msg='branching should result in list if `traverse_string`')
self.assertEqual(traverse_obj(_TRAVERSE_STRING_DATA, ('str', lambda _, x: x), self.assertEqual(traverse_obj({}, (0, Ellipsis), _traverse_string=True), [],
_traverse_string=True), list('str'), msg='branching should result in list if `traverse_string`')
msg='function branching into string should result in list') self.assertEqual(traverse_obj({}, (0, lambda x, y: True), _traverse_string=True), [],
msg='branching should result in list if `traverse_string`')
self.assertEqual(traverse_obj({}, (0, slice(1)), _traverse_string=True), [],
msg='branching should result in list if `traverse_string`')
# Test is_user_input behavior # Test is_user_input behavior
_IS_USER_INPUT_DATA = {'range8': list(range(8))} _IS_USER_INPUT_DATA = {'range8': list(range(8))}
@ -1793,6 +1886,8 @@ Line 1
msg='failing str key on a `re.Match` should return `default`') msg='failing str key on a `re.Match` should return `default`')
self.assertEqual(traverse_obj(mobj, 8), None, self.assertEqual(traverse_obj(mobj, 8), None,
msg='failing int key on a `re.Match` should return `default`') msg='failing int key on a `re.Match` should return `default`')
self.assertEqual(traverse_obj(mobj, lambda k, _: k in (0, 'group')), ['0123', '3'],
msg='function on a `re.Match` should give group name as well')
def test_get_first(self): def test_get_first(self):
self.assertEqual(get_first([{'a': None}, {'a': 'spam'}], 'a'), 'spam') self.assertEqual(get_first([{'a': None}, {'a': 'spam'}], 'a'), 'spam')

View File

@ -5,6 +5,7 @@ from __future__ import absolute_import, unicode_literals
import collections import collections
import contextlib import contextlib
import copy
import datetime import datetime
import errno import errno
import fileinput import fileinput
@ -34,10 +35,12 @@ from string import ascii_letters
from .compat import ( from .compat import (
compat_basestring, compat_basestring,
compat_cookiejar, compat_collections_chain_map as ChainMap,
compat_filter as filter, compat_filter as filter,
compat_get_terminal_size, compat_get_terminal_size,
compat_http_client, compat_http_client,
compat_http_cookiejar_Cookie,
compat_http_cookies_SimpleCookie,
compat_integer_types, compat_integer_types,
compat_kwargs, compat_kwargs,
compat_map as map, compat_map as map,
@ -53,6 +56,7 @@ from .compat import (
from .utils import ( from .utils import (
age_restricted, age_restricted,
args_to_str, args_to_str,
bug_reports_message,
ContentTooShortError, ContentTooShortError,
date_from_str, date_from_str,
DateRange, DateRange,
@ -97,6 +101,7 @@ from .utils import (
std_headers, std_headers,
str_or_none, str_or_none,
subtitles_filename, subtitles_filename,
traverse_obj,
UnavailableVideoError, UnavailableVideoError,
url_basename, url_basename,
version_tuple, version_tuple,
@ -376,6 +381,9 @@ class YoutubeDL(object):
self.params.update(params) self.params.update(params)
self.cache = Cache(self) self.cache = Cache(self)
self._header_cookies = []
self._load_cookies_from_headers(self.params.get('http_headers'))
def check_deprecated(param, option, suggestion): def check_deprecated(param, option, suggestion):
if self.params.get(param) is not None: if self.params.get(param) is not None:
self.report_warning( self.report_warning(
@ -582,7 +590,7 @@ class YoutubeDL(object):
if self.params.get('cookiefile') is not None: if self.params.get('cookiefile') is not None:
self.cookiejar.save(ignore_discard=True, ignore_expires=True) self.cookiejar.save(ignore_discard=True, ignore_expires=True)
def trouble(self, message=None, tb=None): def trouble(self, *args, **kwargs):
"""Determine action to take when a download problem appears. """Determine action to take when a download problem appears.
Depending on if the downloader has been configured to ignore Depending on if the downloader has been configured to ignore
@ -591,6 +599,11 @@ class YoutubeDL(object):
tb, if given, is additional traceback information. tb, if given, is additional traceback information.
""" """
# message=None, tb=None, is_error=True
message = args[0] if len(args) > 0 else kwargs.get('message', None)
tb = args[1] if len(args) > 1 else kwargs.get('tb', None)
is_error = args[2] if len(args) > 2 else kwargs.get('is_error', True)
if message is not None: if message is not None:
self.to_stderr(message) self.to_stderr(message)
if self.params.get('verbose'): if self.params.get('verbose'):
@ -603,7 +616,10 @@ class YoutubeDL(object):
else: else:
tb_data = traceback.format_list(traceback.extract_stack()) tb_data = traceback.format_list(traceback.extract_stack())
tb = ''.join(tb_data) tb = ''.join(tb_data)
if tb:
self.to_stderr(tb) self.to_stderr(tb)
if not is_error:
return
if not self.params.get('ignoreerrors', False): if not self.params.get('ignoreerrors', False):
if sys.exc_info()[0] and hasattr(sys.exc_info()[1], 'exc_info') and sys.exc_info()[1].exc_info[0]: if sys.exc_info()[0] and hasattr(sys.exc_info()[1], 'exc_info') and sys.exc_info()[1].exc_info[0]:
exc_info = sys.exc_info()[1].exc_info exc_info = sys.exc_info()[1].exc_info
@ -612,11 +628,18 @@ class YoutubeDL(object):
raise DownloadError(message, exc_info) raise DownloadError(message, exc_info)
self._download_retcode = 1 self._download_retcode = 1
def report_warning(self, message): def report_warning(self, message, only_once=False, _cache={}):
''' '''
Print the message to stderr, it will be prefixed with 'WARNING:' Print the message to stderr, it will be prefixed with 'WARNING:'
If stderr is a tty file the 'WARNING:' will be colored If stderr is a tty file the 'WARNING:' will be colored
''' '''
if only_once:
m_hash = hash((self, message))
m_cnt = _cache.setdefault(m_hash, 0)
_cache[m_hash] = m_cnt + 1
if m_cnt > 0:
return
if self.params.get('logger') is not None: if self.params.get('logger') is not None:
self.params['logger'].warning(message) self.params['logger'].warning(message)
else: else:
@ -629,7 +652,7 @@ class YoutubeDL(object):
warning_message = '%s %s' % (_msg_header, message) warning_message = '%s %s' % (_msg_header, message)
self.to_stderr(warning_message) self.to_stderr(warning_message)
def report_error(self, message, tb=None): def report_error(self, message, *args, **kwargs):
''' '''
Do the same as trouble, but prefixes the message with 'ERROR:', colored Do the same as trouble, but prefixes the message with 'ERROR:', colored
in red if stderr is a tty file. in red if stderr is a tty file.
@ -638,8 +661,18 @@ class YoutubeDL(object):
_msg_header = '\033[0;31mERROR:\033[0m' _msg_header = '\033[0;31mERROR:\033[0m'
else: else:
_msg_header = 'ERROR:' _msg_header = 'ERROR:'
error_message = '%s %s' % (_msg_header, message) kwargs['message'] = '%s %s' % (_msg_header, message)
self.trouble(error_message, tb) self.trouble(*args, **kwargs)
def report_unscoped_cookies(self, *args, **kwargs):
# message=None, tb=False, is_error=False
if len(args) <= 2:
kwargs.setdefault('is_error', False)
if len(args) <= 0:
kwargs.setdefault(
'message',
'Unscoped cookies are not allowed: please specify some sort of scoping')
self.report_error(*args, **kwargs)
def report_file_already_downloaded(self, file_name): def report_file_already_downloaded(self, file_name):
"""Report file has already been fully downloaded.""" """Report file has already been fully downloaded."""
@ -835,7 +868,7 @@ class YoutubeDL(object):
msg += '\nYou might want to use a VPN or a proxy server (with --proxy) to workaround.' msg += '\nYou might want to use a VPN or a proxy server (with --proxy) to workaround.'
self.report_error(msg) self.report_error(msg)
except ExtractorError as e: # An error we somewhat expected except ExtractorError as e: # An error we somewhat expected
self.report_error(compat_str(e), e.format_traceback()) self.report_error(compat_str(e), tb=e.format_traceback())
except MaxDownloadsReached: except MaxDownloadsReached:
raise raise
except Exception as e: except Exception as e:
@ -845,8 +878,83 @@ class YoutubeDL(object):
raise raise
return wrapper return wrapper
def _remove_cookie_header(self, http_headers):
"""Filters out `Cookie` header from an `http_headers` dict
The `Cookie` header is removed to prevent leaks as a result of unscoped cookies.
See: https://github.com/yt-dlp/yt-dlp/security/advisories/GHSA-v8mc-9377-rwjj
@param http_headers An `http_headers` dict from which any `Cookie` header
should be removed, or None
"""
return dict(filter(lambda pair: pair[0].lower() != 'cookie', (http_headers or {}).items()))
def _load_cookies(self, data, **kwargs):
"""Loads cookies from a `Cookie` header
This tries to work around the security vulnerability of passing cookies to every domain.
@param data The Cookie header as a string to load the cookies from
@param autoscope If `False`, scope cookies using Set-Cookie syntax and error for cookie without domains
If `True`, save cookies for later to be stored in the jar with a limited scope
If a URL, save cookies in the jar with the domain of the URL
"""
# autoscope=True (kw-only)
autoscope = kwargs.get('autoscope', True)
for cookie in compat_http_cookies_SimpleCookie(data).values() if data else []:
if autoscope and any(cookie.values()):
raise ValueError('Invalid syntax in Cookie Header')
domain = cookie.get('domain') or ''
expiry = cookie.get('expires')
if expiry == '': # 0 is valid so we check for `''` explicitly
expiry = None
prepared_cookie = compat_http_cookiejar_Cookie(
cookie.get('version') or 0, cookie.key, cookie.value, None, False,
domain, True, True, cookie.get('path') or '', bool(cookie.get('path')),
bool(cookie.get('secure')), expiry, False, None, None, {})
if domain:
self.cookiejar.set_cookie(prepared_cookie)
elif autoscope is True:
self.report_warning(
'Passing cookies as a header is a potential security risk; '
'they will be scoped to the domain of the downloaded urls. '
'Please consider loading cookies from a file or browser instead.',
only_once=True)
self._header_cookies.append(prepared_cookie)
elif autoscope:
self.report_warning(
'The extractor result contains an unscoped cookie as an HTTP header. '
'If you are specifying an input URL, ' + bug_reports_message(),
only_once=True)
self._apply_header_cookies(autoscope, [prepared_cookie])
else:
self.report_unscoped_cookies()
def _load_cookies_from_headers(self, headers):
self._load_cookies(traverse_obj(headers, 'cookie', casesense=False))
def _apply_header_cookies(self, url, cookies=None):
"""This method applies stray header cookies to the provided url
This loads header cookies and scopes them to the domain provided in `url`.
While this is not ideal, it helps reduce the risk of them being sent to
an unintended destination.
"""
parsed = compat_urllib_parse.urlparse(url)
if not parsed.hostname:
return
for cookie in map(copy.copy, cookies or self._header_cookies):
cookie.domain = '.' + parsed.hostname
self.cookiejar.set_cookie(cookie)
@__handle_extraction_exceptions @__handle_extraction_exceptions
def __extract_info(self, url, ie, download, extra_info, process): def __extract_info(self, url, ie, download, extra_info, process):
# Compat with passing cookies in http headers
self._apply_header_cookies(url)
ie_result = ie.extract(url) ie_result = ie.extract(url)
if ie_result is None: # Finished already (backwards compatibility; listformats and friends should be moved here) if ie_result is None: # Finished already (backwards compatibility; listformats and friends should be moved here)
return return
@ -1443,23 +1551,45 @@ class YoutubeDL(object):
parsed_selector = _parse_format_selection(iter(TokenIterator(tokens))) parsed_selector = _parse_format_selection(iter(TokenIterator(tokens)))
return _build_selector_function(parsed_selector) return _build_selector_function(parsed_selector)
def _calc_headers(self, info_dict): def _calc_headers(self, info_dict, load_cookies=False):
res = std_headers.copy() if load_cookies: # For --load-info-json
# load cookies from http_headers in legacy info.json
self._load_cookies(traverse_obj(info_dict, ('http_headers', 'Cookie'), casesense=False),
autoscope=info_dict['url'])
# load scoped cookies from info.json
self._load_cookies(info_dict.get('cookies'), autoscope=False)
add_headers = info_dict.get('http_headers') cookies = self.cookiejar.get_cookies_for_url(info_dict['url'])
if add_headers:
res.update(add_headers)
cookies = self._calc_cookies(info_dict)
if cookies: if cookies:
res['Cookie'] = cookies # Make a string like name1=val1; attr1=a_val1; ...name2=val2; ...
# By convention a cookie name can't be a well-known attribute name
# so this syntax is unambiguous and can be parsed by (eg) SimpleCookie
encoder = compat_http_cookies_SimpleCookie()
values = []
attributes = (('Domain', '='), ('Path', '='), ('Secure',), ('Expires', '='), ('Version', '='))
attributes = tuple([x[0].lower()] + list(x) for x in attributes)
for cookie in cookies:
_, value = encoder.value_encode(cookie.value)
# Py 2 '' --> '', Py 3 '' --> '""'
if value == '':
value = '""'
values.append('='.join((cookie.name, value)))
for attr in attributes:
value = getattr(cookie, attr[0], None)
if value:
values.append('%s%s' % (''.join(attr[1:]), value if len(attr) == 3 else ''))
info_dict['cookies'] = '; '.join(values)
res = std_headers.copy()
res.update(info_dict.get('http_headers') or {})
res = self._remove_cookie_header(res)
if 'X-Forwarded-For' not in res: if 'X-Forwarded-For' not in res:
x_forwarded_for_ip = info_dict.get('__x_forwarded_for_ip') x_forwarded_for_ip = info_dict.get('__x_forwarded_for_ip')
if x_forwarded_for_ip: if x_forwarded_for_ip:
res['X-Forwarded-For'] = x_forwarded_for_ip res['X-Forwarded-For'] = x_forwarded_for_ip
return res return res or None
def _calc_cookies(self, info_dict): def _calc_cookies(self, info_dict):
pr = sanitized_Request(info_dict['url']) pr = sanitized_Request(info_dict['url'])
@ -1638,10 +1768,13 @@ class YoutubeDL(object):
format['protocol'] = determine_protocol(format) format['protocol'] = determine_protocol(format)
# Add HTTP headers, so that external programs can use them from the # Add HTTP headers, so that external programs can use them from the
# json output # json output
full_format_info = info_dict.copy() format['http_headers'] = self._calc_headers(ChainMap(format, info_dict), load_cookies=True)
full_format_info.update(format)
format['http_headers'] = self._calc_headers(full_format_info) # Safeguard against old/insecure infojson when using --load-info-json
# Remove private housekeeping stuff info_dict['http_headers'] = self._remove_cookie_header(
info_dict.get('http_headers') or {}) or None
# Remove private housekeeping stuff (copied to http_headers in _calc_headers())
if '__x_forwarded_for_ip' in info_dict: if '__x_forwarded_for_ip' in info_dict:
del info_dict['__x_forwarded_for_ip'] del info_dict['__x_forwarded_for_ip']
@ -1902,17 +2035,9 @@ class YoutubeDL(object):
(sub_lang, error_to_compat_str(err))) (sub_lang, error_to_compat_str(err)))
continue continue
if self.params.get('writeinfojson', False): self._write_info_json(
infofn = replace_extension(filename, 'info.json', info_dict.get('ext')) 'video description', info_dict,
if self.params.get('nooverwrites', False) and os.path.exists(encodeFilename(infofn)): replace_extension(filename, 'info.json', info_dict.get('ext')))
self.to_screen('[info] Video description metadata is already present')
else:
self.to_screen('[info] Writing video description metadata as JSON to: ' + infofn)
try:
write_json_file(self.filter_requested_info(info_dict), infofn)
except (OSError, IOError):
self.report_error('Cannot write metadata to JSON file ' + infofn)
return
self._write_thumbnails(info_dict, filename) self._write_thumbnails(info_dict, filename)
@ -1933,7 +2058,11 @@ class YoutubeDL(object):
fd.add_progress_hook(ph) fd.add_progress_hook(ph)
if self.params.get('verbose'): if self.params.get('verbose'):
self.to_screen('[debug] Invoking downloader on %r' % info.get('url')) self.to_screen('[debug] Invoking downloader on %r' % info.get('url'))
return fd.download(name, info)
new_info = dict((k, v) for k, v in info.items() if not k.startswith('__p'))
new_info['http_headers'] = self._calc_headers(new_info)
return fd.download(name, new_info)
if info_dict.get('requested_formats') is not None: if info_dict.get('requested_formats') is not None:
downloaded = [] downloaded = []
@ -2378,10 +2507,12 @@ class YoutubeDL(object):
self.get_encoding())) self.get_encoding()))
write_string(encoding_str, encoding=None) write_string(encoding_str, encoding=None)
self._write_string('[debug] youtube-dl version ' + __version__ + (' (single file build)\n' if ytdl_is_updateable() else '\n')) writeln_debug = lambda *s: self._write_string('[debug] %s\n' % (''.join(s), ))
writeln_debug('youtube-dl version ', __version__)
if _LAZY_LOADER: if _LAZY_LOADER:
self._write_string('[debug] Lazy loading extractors enabled\n') writeln_debug('Lazy loading extractors enabled')
writeln_debug = lambda *s: self._write_string('[debug] %s\n' % (''.join(s), )) # moved down for easier merge if ytdl_is_updateable():
writeln_debug('Single file build')
try: try:
sp = subprocess.Popen( sp = subprocess.Popen(
['git', 'rev-parse', '--short', 'HEAD'], ['git', 'rev-parse', '--short', 'HEAD'],
@ -2457,7 +2588,7 @@ class YoutubeDL(object):
opts_proxy = self.params.get('proxy') opts_proxy = self.params.get('proxy')
if opts_cookiefile is None: if opts_cookiefile is None:
self.cookiejar = compat_cookiejar.CookieJar() self.cookiejar = YoutubeDLCookieJar()
else: else:
opts_cookiefile = expand_path(opts_cookiefile) opts_cookiefile = expand_path(opts_cookiefile)
self.cookiejar = YoutubeDLCookieJar(opts_cookiefile) self.cookiejar = YoutubeDLCookieJar(opts_cookiefile)
@ -2518,6 +2649,28 @@ class YoutubeDL(object):
encoding = preferredencoding() encoding = preferredencoding()
return encoding return encoding
def _write_info_json(self, label, info_dict, infofn, overwrite=None):
if not self.params.get('writeinfojson', False):
return False
def msg(fmt, lbl):
return fmt % (lbl + ' metadata',)
if overwrite is None:
overwrite = not self.params.get('nooverwrites', False)
if not overwrite and os.path.exists(encodeFilename(infofn)):
self.to_screen(msg('[info] %s is already present', label.title()))
return 'exists'
else:
self.to_screen(msg('[info] Writing %s as JSON to: ' + infofn, label))
try:
write_json_file(self.filter_requested_info(info_dict), infofn)
return True
except (OSError, IOError):
self.report_error(msg('Cannot write %s to JSON file ' + infofn, label))
return
def _write_thumbnails(self, info_dict, filename): def _write_thumbnails(self, info_dict, filename):
if self.params.get('writethumbnail', False): if self.params.get('writethumbnail', False):
thumbnails = info_dict.get('thumbnails') thumbnails = info_dict.get('thumbnails')

View File

@ -21,6 +21,7 @@ import socket
import struct import struct
import subprocess import subprocess
import sys import sys
import types
import xml.etree.ElementTree import xml.etree.ElementTree
# naming convention # naming convention
@ -55,6 +56,22 @@ try:
except ImportError: # Python 2 except ImportError: # Python 2
import urllib2 as compat_urllib_request import urllib2 as compat_urllib_request
# Also fix up lack of method arg in old Pythons
try:
_req = compat_urllib_request.Request
_req('http://127.0.0.1', method='GET')
except TypeError:
class _request(object):
def __new__(cls, url, *args, **kwargs):
method = kwargs.pop('method', None)
r = _req(url, *args, **kwargs)
if method:
r.get_method = types.MethodType(lambda _: method, r)
return r
compat_urllib_request.Request = _request
try: try:
import urllib.error as compat_urllib_error import urllib.error as compat_urllib_error
except ImportError: # Python 2 except ImportError: # Python 2
@ -79,6 +96,12 @@ try:
except ImportError: # Python 2 except ImportError: # Python 2
import urllib as compat_urllib_response import urllib as compat_urllib_response
try:
compat_urllib_response.addinfourl.status
except AttributeError:
# .getcode() is deprecated in Py 3.
compat_urllib_response.addinfourl.status = property(lambda self: self.getcode())
try: try:
import http.cookiejar as compat_cookiejar import http.cookiejar as compat_cookiejar
except ImportError: # Python 2 except ImportError: # Python 2
@ -103,12 +126,24 @@ except ImportError: # Python 2
import Cookie as compat_cookies import Cookie as compat_cookies
compat_http_cookies = compat_cookies compat_http_cookies = compat_cookies
if sys.version_info[0] == 2: if sys.version_info[0] == 2 or sys.version_info < (3, 3):
class compat_cookies_SimpleCookie(compat_cookies.SimpleCookie): class compat_cookies_SimpleCookie(compat_cookies.SimpleCookie):
def load(self, rawdata): def load(self, rawdata):
must_have_value = 0
if not isinstance(rawdata, dict):
if sys.version_info[:2] != (2, 7):
# attribute must have value for parsing
rawdata, must_have_value = re.subn(
r'(?i)(;\s*)(secure|httponly)(\s*(?:;|$))', r'\1\2=\2\3', rawdata)
if sys.version_info[0] == 2:
if isinstance(rawdata, compat_str): if isinstance(rawdata, compat_str):
rawdata = str(rawdata) rawdata = str(rawdata)
return super(compat_cookies_SimpleCookie, self).load(rawdata) super(compat_cookies_SimpleCookie, self).load(rawdata)
if must_have_value > 0:
for morsel in self.values():
for attr in ('secure', 'httponly'):
if morsel.get(attr):
morsel[attr] = True
else: else:
compat_cookies_SimpleCookie = compat_cookies.SimpleCookie compat_cookies_SimpleCookie = compat_cookies.SimpleCookie
compat_http_cookies_SimpleCookie = compat_cookies_SimpleCookie compat_http_cookies_SimpleCookie = compat_cookies_SimpleCookie
@ -2360,6 +2395,11 @@ try:
import http.client as compat_http_client import http.client as compat_http_client
except ImportError: # Python 2 except ImportError: # Python 2
import httplib as compat_http_client import httplib as compat_http_client
try:
compat_http_client.HTTPResponse.getcode
except AttributeError:
# Py < 3.1
compat_http_client.HTTPResponse.getcode = lambda self: self.status
try: try:
from urllib.error import HTTPError as compat_HTTPError from urllib.error import HTTPError as compat_HTTPError

View File

@ -339,6 +339,10 @@ class FileDownloader(object):
def download(self, filename, info_dict): def download(self, filename, info_dict):
"""Download to a filename using the info from info_dict """Download to a filename using the info from info_dict
Return True on success and False otherwise Return True on success and False otherwise
This method filters the `Cookie` header from the info_dict to prevent leaks.
Downloaders have their own way of handling cookies.
See: https://github.com/yt-dlp/yt-dlp/security/advisories/GHSA-v8mc-9377-rwjj
""" """
nooverwrites_and_exists = ( nooverwrites_and_exists = (

View File

@ -1,9 +1,10 @@
from __future__ import unicode_literals from __future__ import unicode_literals
import os.path import os
import re import re
import subprocess import subprocess
import sys import sys
import tempfile
import time import time
from .common import FileDownloader from .common import FileDownloader
@ -23,6 +24,8 @@ from ..utils import (
check_executable, check_executable,
is_outdated_version, is_outdated_version,
process_communicate_or_kill, process_communicate_or_kill,
T,
traverse_obj,
) )
@ -30,6 +33,7 @@ class ExternalFD(FileDownloader):
def real_download(self, filename, info_dict): def real_download(self, filename, info_dict):
self.report_destination(filename) self.report_destination(filename)
tmpfilename = self.temp_name(filename) tmpfilename = self.temp_name(filename)
self._cookies_tempfile = None
try: try:
started = time.time() started = time.time()
@ -42,6 +46,13 @@ class ExternalFD(FileDownloader):
# should take place # should take place
retval = 0 retval = 0
self.to_screen('[%s] Interrupted by user' % self.get_basename()) self.to_screen('[%s] Interrupted by user' % self.get_basename())
finally:
if self._cookies_tempfile and os.path.isfile(self._cookies_tempfile):
try:
os.remove(self._cookies_tempfile)
except OSError:
self.report_warning(
'Unable to delete temporary cookies file "{0}"'.format(self._cookies_tempfile))
if retval == 0: if retval == 0:
status = { status = {
@ -97,6 +108,16 @@ class ExternalFD(FileDownloader):
def _configuration_args(self, default=[]): def _configuration_args(self, default=[]):
return cli_configuration_args(self.params, 'external_downloader_args', default) return cli_configuration_args(self.params, 'external_downloader_args', default)
def _write_cookies(self):
if not self.ydl.cookiejar.filename:
tmp_cookies = tempfile.NamedTemporaryFile(suffix='.cookies', delete=False)
tmp_cookies.close()
self._cookies_tempfile = tmp_cookies.name
self.to_screen('[download] Writing temporary cookies file to "{0}"'.format(self._cookies_tempfile))
# real_download resets _cookies_tempfile; if it's None, save() will write to cookiejar.filename
self.ydl.cookiejar.save(self._cookies_tempfile, ignore_discard=True, ignore_expires=True)
return self.ydl.cookiejar.filename or self._cookies_tempfile
def _call_downloader(self, tmpfilename, info_dict): def _call_downloader(self, tmpfilename, info_dict):
""" Either overwrite this or implement _make_cmd """ """ Either overwrite this or implement _make_cmd """
cmd = [encodeArgument(a) for a in self._make_cmd(tmpfilename, info_dict)] cmd = [encodeArgument(a) for a in self._make_cmd(tmpfilename, info_dict)]
@ -110,13 +131,21 @@ class ExternalFD(FileDownloader):
self.to_stderr(stderr.decode('utf-8', 'replace')) self.to_stderr(stderr.decode('utf-8', 'replace'))
return p.returncode return p.returncode
@staticmethod
def _header_items(info_dict):
return traverse_obj(
info_dict, ('http_headers', T(dict.items), Ellipsis))
class CurlFD(ExternalFD): class CurlFD(ExternalFD):
AVAILABLE_OPT = '-V' AVAILABLE_OPT = '-V'
def _make_cmd(self, tmpfilename, info_dict): def _make_cmd(self, tmpfilename, info_dict):
cmd = [self.exe, '--location', '-o', tmpfilename] cmd = [self.exe, '--location', '-o', tmpfilename, '--compressed']
for key, val in info_dict['http_headers'].items(): cookie_header = self.ydl.cookiejar.get_cookie_header(info_dict['url'])
if cookie_header:
cmd += ['--cookie', cookie_header]
for key, val in self._header_items(info_dict):
cmd += ['--header', '%s: %s' % (key, val)] cmd += ['--header', '%s: %s' % (key, val)]
cmd += self._bool_option('--continue-at', 'continuedl', '-', '0') cmd += self._bool_option('--continue-at', 'continuedl', '-', '0')
cmd += self._valueless_option('--silent', 'noprogress') cmd += self._valueless_option('--silent', 'noprogress')
@ -151,8 +180,11 @@ class AxelFD(ExternalFD):
def _make_cmd(self, tmpfilename, info_dict): def _make_cmd(self, tmpfilename, info_dict):
cmd = [self.exe, '-o', tmpfilename] cmd = [self.exe, '-o', tmpfilename]
for key, val in info_dict['http_headers'].items(): for key, val in self._header_items(info_dict):
cmd += ['-H', '%s: %s' % (key, val)] cmd += ['-H', '%s: %s' % (key, val)]
cookie_header = self.ydl.cookiejar.get_cookie_header(info_dict['url'])
if cookie_header:
cmd += ['-H', 'Cookie: {0}'.format(cookie_header), '--max-redirect=0']
cmd += self._configuration_args() cmd += self._configuration_args()
cmd += ['--', info_dict['url']] cmd += ['--', info_dict['url']]
return cmd return cmd
@ -162,8 +194,10 @@ class WgetFD(ExternalFD):
AVAILABLE_OPT = '--version' AVAILABLE_OPT = '--version'
def _make_cmd(self, tmpfilename, info_dict): def _make_cmd(self, tmpfilename, info_dict):
cmd = [self.exe, '-O', tmpfilename, '-nv', '--no-cookies'] cmd = [self.exe, '-O', tmpfilename, '-nv', '--compression=auto']
for key, val in info_dict['http_headers'].items(): if self.ydl.cookiejar.get_cookie_header(info_dict['url']):
cmd += ['--load-cookies', self._write_cookies()]
for key, val in self._header_items(info_dict):
cmd += ['--header', '%s: %s' % (key, val)] cmd += ['--header', '%s: %s' % (key, val)]
cmd += self._option('--limit-rate', 'ratelimit') cmd += self._option('--limit-rate', 'ratelimit')
retry = self._option('--tries', 'retries') retry = self._option('--tries', 'retries')
@ -182,20 +216,57 @@ class WgetFD(ExternalFD):
class Aria2cFD(ExternalFD): class Aria2cFD(ExternalFD):
AVAILABLE_OPT = '-v' AVAILABLE_OPT = '-v'
@staticmethod
def _aria2c_filename(fn):
return fn if os.path.isabs(fn) else os.path.join('.', fn)
def _make_cmd(self, tmpfilename, info_dict): def _make_cmd(self, tmpfilename, info_dict):
cmd = [self.exe, '-c'] cmd = [self.exe, '-c',
cmd += self._configuration_args([ '--console-log-level=warn', '--summary-interval=0', '--download-result=hide',
'--min-split-size', '1M', '--max-connection-per-server', '4']) '--http-accept-gzip=true', '--file-allocation=none', '-x16', '-j16', '-s16']
dn = os.path.dirname(tmpfilename) if 'fragments' in info_dict:
if dn: cmd += ['--allow-overwrite=true', '--allow-piece-length-change=true']
cmd += ['--dir', dn] else:
cmd += ['--out', os.path.basename(tmpfilename)] cmd += ['--min-split-size', '1M']
for key, val in info_dict['http_headers'].items():
if self.ydl.cookiejar.get_cookie_header(info_dict['url']):
cmd += ['--load-cookies={0}'.format(self._write_cookies())]
for key, val in self._header_items(info_dict):
cmd += ['--header', '%s: %s' % (key, val)] cmd += ['--header', '%s: %s' % (key, val)]
cmd += self._configuration_args(['--max-connection-per-server', '4'])
cmd += ['--out', os.path.basename(tmpfilename)]
cmd += self._option('--max-overall-download-limit', 'ratelimit')
cmd += self._option('--interface', 'source_address') cmd += self._option('--interface', 'source_address')
cmd += self._option('--all-proxy', 'proxy') cmd += self._option('--all-proxy', 'proxy')
cmd += self._bool_option('--check-certificate', 'nocheckcertificate', 'false', 'true', '=') cmd += self._bool_option('--check-certificate', 'nocheckcertificate', 'false', 'true', '=')
cmd += self._bool_option('--remote-time', 'updatetime', 'true', 'false', '=') cmd += self._bool_option('--remote-time', 'updatetime', 'true', 'false', '=')
cmd += self._bool_option('--show-console-readout', 'noprogress', 'false', 'true', '=')
cmd += self._configuration_args()
# aria2c strips out spaces from the beginning/end of filenames and paths.
# We work around this issue by adding a "./" to the beginning of the
# filename and relative path, and adding a "/" at the end of the path.
# See: https://github.com/yt-dlp/yt-dlp/issues/276
# https://github.com/ytdl-org/youtube-dl/issues/20312
# https://github.com/aria2/aria2/issues/1373
dn = os.path.dirname(tmpfilename)
if dn:
cmd += ['--dir', self._aria2c_filename(dn) + os.path.sep]
if 'fragments' not in info_dict:
cmd += ['--out', self._aria2c_filename(os.path.basename(tmpfilename))]
cmd += ['--auto-file-renaming=false']
if 'fragments' in info_dict:
cmd += ['--file-allocation=none', '--uri-selector=inorder']
url_list_file = '%s.frag.urls' % (tmpfilename, )
url_list = []
for frag_index, fragment in enumerate(info_dict['fragments']):
fragment_filename = '%s-Frag%d' % (os.path.basename(tmpfilename), frag_index)
url_list.append('%s\n\tout=%s' % (fragment['url'], self._aria2c_filename(fragment_filename)))
stream, _ = self.sanitize_open(url_list_file, 'wb')
stream.write('\n'.join(url_list).encode())
stream.close()
cmd += ['-i', self._aria2c_filename(url_list_file)]
else:
cmd += ['--', info_dict['url']] cmd += ['--', info_dict['url']]
return cmd return cmd
@ -235,8 +306,10 @@ class Aria2pFD(ExternalFD):
} }
options['dir'] = os.path.dirname(tmpfilename) or os.path.abspath('.') options['dir'] = os.path.dirname(tmpfilename) or os.path.abspath('.')
options['out'] = os.path.basename(tmpfilename) options['out'] = os.path.basename(tmpfilename)
if self.ydl.cookiejar.get_cookie_header(info_dict['url']):
options['load-cookies'] = self._write_cookies()
options['header'] = [] options['header'] = []
for key, val in info_dict['http_headers'].items(): for key, val in self._header_items(info_dict):
options['header'].append('{0}: {1}'.format(key, val)) options['header'].append('{0}: {1}'.format(key, val))
download = aria2.add_uris([info_dict['url']], options) download = aria2.add_uris([info_dict['url']], options)
status = { status = {
@ -265,8 +338,16 @@ class HttpieFD(ExternalFD):
def _make_cmd(self, tmpfilename, info_dict): def _make_cmd(self, tmpfilename, info_dict):
cmd = ['http', '--download', '--output', tmpfilename, info_dict['url']] cmd = ['http', '--download', '--output', tmpfilename, info_dict['url']]
for key, val in info_dict['http_headers'].items(): for key, val in self._header_items(info_dict):
cmd += ['%s:%s' % (key, val)] cmd += ['%s:%s' % (key, val)]
# httpie 3.1.0+ removes the Cookie header on redirect, so this should be safe for now. [1]
# If we ever need cookie handling for redirects, we can export the cookiejar into a session. [2]
# 1: https://github.com/httpie/httpie/security/advisories/GHSA-9w4w-cpc8-h2fq
# 2: https://httpie.io/docs/cli/sessions
cookie_header = self.ydl.cookiejar.get_cookie_header(info_dict['url'])
if cookie_header:
cmd += ['Cookie:%s' % cookie_header]
return cmd return cmd
@ -312,7 +393,14 @@ class FFmpegFD(ExternalFD):
# if end_time: # if end_time:
# args += ['-t', compat_str(end_time - start_time)] # args += ['-t', compat_str(end_time - start_time)]
if info_dict['http_headers'] and re.match(r'^https?://', url): cookies = self.ydl.cookiejar.get_cookies_for_url(url)
if cookies:
args.extend(['-cookies', ''.join(
'{0}={1}; path={2}; domain={3};\r\n'.format(
cookie.name, cookie.value, cookie.path, cookie.domain)
for cookie in cookies)])
if info_dict.get('http_headers') and re.match(r'^https?://', url):
# Trailing \r\n after each HTTP header is important to prevent warning from ffmpeg/avconv: # Trailing \r\n after each HTTP header is important to prevent warning from ffmpeg/avconv:
# [http @ 00000000003d2fa0] No trailing CRLF found in HTTP header. # [http @ 00000000003d2fa0] No trailing CRLF found in HTTP header.
headers = handle_youtubedl_headers(info_dict['http_headers']) headers = handle_youtubedl_headers(info_dict['http_headers'])

View File

@ -280,16 +280,16 @@ class JSInterpreter(object):
# make Py 2.6 conform to its lying documentation # make Py 2.6 conform to its lying documentation
if name == 'flags': if name == 'flags':
self.flags = self.__flags self.flags = self.__flags
return self.flags
elif name == 'pattern': elif name == 'pattern':
self.pattern = self.__pattern_txt self.pattern = self.__pattern_txt
return self.pattern
elif hasattr(self.__self, name):
v = getattr(self.__self, name)
setattr(self, name, v)
return v
elif name in ('groupindex', 'groups'): elif name in ('groupindex', 'groups'):
# in case these get set after a match?
if hasattr(self.__self, name):
setattr(self, name, getattr(self.__self, name))
else:
return 0 if name == 'groupindex' else {} return 0 if name == 'groupindex' else {}
if hasattr(self, name):
return getattr(self, name)
raise AttributeError('{0} has no attribute named {1}'.format(self, name)) raise AttributeError('{0} has no attribute named {1}'.format(self, name))
@classmethod @classmethod

View File

@ -544,12 +544,14 @@ def parseOpts(overrideArguments=None):
workarounds.add_option( workarounds.add_option(
'--referer', '--referer',
metavar='URL', dest='referer', default=None, metavar='URL', dest='referer', default=None,
help='Specify a custom referer, use if the video access is restricted to one domain', help='Specify a custom Referer: use if the video access is restricted to one domain',
) )
workarounds.add_option( workarounds.add_option(
'--add-header', '--add-header',
metavar='FIELD:VALUE', dest='headers', action='append', metavar='FIELD:VALUE', dest='headers', action='append',
help='Specify a custom HTTP header and its value, separated by a colon \':\'. You can use this option multiple times', help=('Specify a custom HTTP header and its value, separated by a colon \':\'. You can use this option multiple times. '
'NB Use --cookies rather than adding a Cookie header if its contents may be sensitive; '
'data from a Cookie header will be sent to all domains, not just the one intended')
) )
workarounds.add_option( workarounds.add_option(
'--bidi-workaround', '--bidi-workaround',

View File

@ -16,6 +16,7 @@ import email.header
import errno import errno
import functools import functools
import gzip import gzip
import inspect
import io import io
import itertools import itertools
import json import json
@ -40,7 +41,6 @@ import zlib
from .compat import ( from .compat import (
compat_HTMLParseError, compat_HTMLParseError,
compat_HTMLParser, compat_HTMLParser,
compat_HTTPError,
compat_basestring, compat_basestring,
compat_casefold, compat_casefold,
compat_chr, compat_chr,
@ -63,6 +63,7 @@ from .compat import (
compat_struct_pack, compat_struct_pack,
compat_struct_unpack, compat_struct_unpack,
compat_urllib_error, compat_urllib_error,
compat_urllib_HTTPError,
compat_urllib_parse, compat_urllib_parse,
compat_urllib_parse_parse_qs as compat_parse_qs, compat_urllib_parse_parse_qs as compat_parse_qs,
compat_urllib_parse_urlencode, compat_urllib_parse_urlencode,
@ -2613,7 +2614,8 @@ class YoutubeDLHandler(compat_urllib_request.HTTPHandler):
Part of this code was copied from: Part of this code was copied from:
http://techknack.net/python-urllib2-handlers/ http://techknack.net/python-urllib2-handlers/, archived at
https://web.archive.org/web/20130527205558/http://techknack.net/python-urllib2-handlers/
Andrew Rowls, the author of that code, agreed to release it to the Andrew Rowls, the author of that code, agreed to release it to the
public domain. public domain.
@ -2671,7 +2673,9 @@ class YoutubeDLHandler(compat_urllib_request.HTTPHandler):
req._Request__original = req._Request__original.partition('#')[0] req._Request__original = req._Request__original.partition('#')[0]
req._Request__r_type = req._Request__r_type.partition('#')[0] req._Request__r_type = req._Request__r_type.partition('#')[0]
return req # Use the totally undocumented AbstractHTTPHandler per
# https://github.com/yt-dlp/yt-dlp/pull/4158
return compat_urllib_request.AbstractHTTPHandler.do_request_(self, req)
def http_response(self, req, resp): def http_response(self, req, resp):
old_resp = resp old_resp = resp
@ -2682,7 +2686,7 @@ class YoutubeDLHandler(compat_urllib_request.HTTPHandler):
try: try:
uncompressed = io.BytesIO(gz.read()) uncompressed = io.BytesIO(gz.read())
except IOError as original_ioerror: except IOError as original_ioerror:
# There may be junk add the end of the file # There may be junk at the end of the file
# See http://stackoverflow.com/q/4928560/35070 for details # See http://stackoverflow.com/q/4928560/35070 for details
for i in range(1, 1024): for i in range(1, 1024):
try: try:
@ -2709,8 +2713,7 @@ class YoutubeDLHandler(compat_urllib_request.HTTPHandler):
if location: if location:
# As of RFC 2616 default charset is iso-8859-1 that is respected by python 3 # As of RFC 2616 default charset is iso-8859-1 that is respected by python 3
if sys.version_info >= (3, 0): if sys.version_info >= (3, 0):
location = location.encode('iso-8859-1').decode('utf-8') location = location.encode('iso-8859-1')
else:
location = location.decode('utf-8') location = location.decode('utf-8')
location_escaped = escape_url(location) location_escaped = escape_url(location)
if location != location_escaped: if location != location_escaped:
@ -2909,6 +2912,19 @@ class YoutubeDLCookieJar(compat_cookiejar.MozillaCookieJar):
cookie.expires = None cookie.expires = None
cookie.discard = True cookie.discard = True
def get_cookie_header(self, url):
"""Generate a Cookie HTTP header for a given url"""
cookie_req = sanitized_Request(url)
self.add_cookie_header(cookie_req)
return cookie_req.get_header('Cookie')
def get_cookies_for_url(self, url):
"""Generate a list of Cookie objects for a given url"""
# Policy `_now` attribute must be set before calling `_cookies_for_request`
# Ref: https://github.com/python/cpython/blob/3.7/Lib/http/cookiejar.py#L1360
self._policy._now = self._now = int(time.time())
return self._cookies_for_request(sanitized_Request(url))
class YoutubeDLCookieProcessor(compat_urllib_request.HTTPCookieProcessor): class YoutubeDLCookieProcessor(compat_urllib_request.HTTPCookieProcessor):
def __init__(self, cookiejar=None): def __init__(self, cookiejar=None):
@ -2939,17 +2955,16 @@ class YoutubeDLRedirectHandler(compat_urllib_request.HTTPRedirectHandler):
The code is based on HTTPRedirectHandler implementation from CPython [1]. The code is based on HTTPRedirectHandler implementation from CPython [1].
This redirect handler solves two issues: This redirect handler fixes and improves the logic to better align with RFC7261
- ensures redirect URL is always unicode under python 2 and what browsers tend to do [2][3]
- introduces support for experimental HTTP response status code
308 Permanent Redirect [2] used by some sites [3]
1. https://github.com/python/cpython/blob/master/Lib/urllib/request.py 1. https://github.com/python/cpython/blob/master/Lib/urllib/request.py
2. https://developer.mozilla.org/en-US/docs/Web/HTTP/Status/308 2. https://datatracker.ietf.org/doc/html/rfc7231
3. https://github.com/ytdl-org/youtube-dl/issues/28768 3. https://github.com/python/cpython/issues/91306
""" """
http_error_301 = http_error_303 = http_error_307 = http_error_308 = compat_urllib_request.HTTPRedirectHandler.http_error_302 # Supply possibly missing alias
http_error_308 = compat_urllib_request.HTTPRedirectHandler.http_error_302
def redirect_request(self, req, fp, code, msg, headers, newurl): def redirect_request(self, req, fp, code, msg, headers, newurl):
"""Return a Request or None in response to a redirect. """Return a Request or None in response to a redirect.
@ -2961,19 +2976,15 @@ class YoutubeDLRedirectHandler(compat_urllib_request.HTTPRedirectHandler):
else should try to handle this url. Return None if you can't else should try to handle this url. Return None if you can't
but another Handler might. but another Handler might.
""" """
m = req.get_method() if code not in (301, 302, 303, 307, 308):
if (not (code in (301, 302, 303, 307, 308) and m in ("GET", "HEAD") raise compat_urllib_HTTPError(req.full_url, code, msg, headers, fp)
or code in (301, 302, 303) and m == "POST")):
raise compat_HTTPError(req.full_url, code, msg, headers, fp) new_method = req.get_method()
# Strictly (according to RFC 2616), 301 or 302 in response to new_data = req.data
# a POST MUST NOT cause a redirection without confirmation
# from the user (of urllib.request, in this case). In practice,
# essentially all clients do redirect in this case, so we do
# the same.
# On python 2 urlh.geturl() may sometimes return redirect URL # On python 2 urlh.geturl() may sometimes return redirect URL
# as byte string instead of unicode. This workaround allows # as a byte string instead of unicode. This workaround forces
# to force it always return unicode. # it to return unicode.
if sys.version_info[0] < 3: if sys.version_info[0] < 3:
newurl = compat_str(newurl) newurl = compat_str(newurl)
@ -2982,13 +2993,34 @@ class YoutubeDLRedirectHandler(compat_urllib_request.HTTPRedirectHandler):
# but it is kept for compatibility with other callers. # but it is kept for compatibility with other callers.
newurl = newurl.replace(' ', '%20') newurl = newurl.replace(' ', '%20')
CONTENT_HEADERS = ("content-length", "content-type") # Technically the Cookie header should be in unredirected_hdrs;
# however in practice some may set it in normal headers anyway.
# We will remove it here to prevent any leaks.
remove_headers = ['Cookie']
# A 303 must either use GET or HEAD for subsequent request
# https://datatracker.ietf.org/doc/html/rfc7231#section-6.4.4
if code == 303 and req.get_method() != 'HEAD':
new_method = 'GET'
# 301 and 302 redirects are commonly turned into a GET from a POST
# for subsequent requests by browsers, so we'll do the same.
# https://datatracker.ietf.org/doc/html/rfc7231#section-6.4.2
# https://datatracker.ietf.org/doc/html/rfc7231#section-6.4.3
elif code in (301, 302) and req.get_method() == 'POST':
new_method = 'GET'
# only remove payload if method changed (e.g. POST to GET)
if new_method != req.get_method():
new_data = None
remove_headers.extend(['Content-Length', 'Content-Type'])
# NB: don't use dict comprehension for python 2.6 compatibility # NB: don't use dict comprehension for python 2.6 compatibility
newheaders = dict((k, v) for k, v in req.headers.items() new_headers = dict((k, v) for k, v in req.header_items()
if k.lower() not in CONTENT_HEADERS) if k.title() not in remove_headers)
return compat_urllib_request.Request( return compat_urllib_request.Request(
newurl, headers=newheaders, origin_req_host=req.origin_req_host, newurl, headers=new_headers, origin_req_host=req.origin_req_host,
unverifiable=True) unverifiable=True, method=new_method, data=new_data)
def extract_timezone(date_str): def extract_timezone(date_str):
@ -3881,7 +3913,7 @@ def detect_exe_version(output, version_re=None, unrecognized='present'):
return unrecognized return unrecognized
class LazyList(compat_collections_abc.Sequence): class LazyList(compat_collections_abc.Iterable):
"""Lazy immutable list from an iterable """Lazy immutable list from an iterable
Note that slices of a LazyList are lists and not LazyList""" Note that slices of a LazyList are lists and not LazyList"""
@ -4223,10 +4255,16 @@ def multipart_encode(data, boundary=None):
return out, content_type return out, content_type
def variadic(x, allowed_types=(compat_str, bytes, dict)): def is_iterable_like(x, allowed_types=compat_collections_abc.Iterable, blocked_types=NO_DEFAULT):
if not isinstance(allowed_types, tuple) and isinstance(allowed_types, compat_collections_abc.Iterable): if blocked_types is NO_DEFAULT:
blocked_types = (compat_str, bytes, compat_collections_abc.Mapping)
return isinstance(x, allowed_types) and not isinstance(x, blocked_types)
def variadic(x, allowed_types=NO_DEFAULT):
if isinstance(allowed_types, compat_collections_abc.Iterable):
allowed_types = tuple(allowed_types) allowed_types = tuple(allowed_types)
return x if isinstance(x, compat_collections_abc.Iterable) and not isinstance(x, allowed_types) else (x,) return x if is_iterable_like(x, blocked_types=allowed_types) else (x,)
def dict_get(d, key_or_keys, default=None, skip_false_values=True): def dict_get(d, key_or_keys, default=None, skip_false_values=True):
@ -5993,7 +6031,7 @@ def clean_podcast_url(url):
def traverse_obj(obj, *paths, **kwargs): def traverse_obj(obj, *paths, **kwargs):
""" """
Safely traverse nested `dict`s and `Sequence`s Safely traverse nested `dict`s and `Iterable`s
>>> obj = [{}, {"key": "value"}] >>> obj = [{}, {"key": "value"}]
>>> traverse_obj(obj, (1, "key")) >>> traverse_obj(obj, (1, "key"))
@ -6001,14 +6039,17 @@ def traverse_obj(obj, *paths, **kwargs):
Each of the provided `paths` is tested and the first producing a valid result will be returned. Each of the provided `paths` is tested and the first producing a valid result will be returned.
The next path will also be tested if the path branched but no results could be found. The next path will also be tested if the path branched but no results could be found.
Supported values for traversal are `Mapping`, `Sequence` and `re.Match`. Supported values for traversal are `Mapping`, `Iterable` and `re.Match`.
A value of None is treated as the absence of a value. Unhelpful values (`{}`, `None`) are treated as the absence of a value and discarded.
The paths will be wrapped in `variadic`, so that `'key'` is conveniently the same as `('key', )`. The paths will be wrapped in `variadic`, so that `'key'` is conveniently the same as `('key', )`.
The keys in the path can be one of: The keys in the path can be one of:
- `None`: Return the current object. - `None`: Return the current object.
- `str`/`int`: Return `obj[key]`. For `re.Match, return `obj.group(key)`. - `set`: Requires the only item in the set to be a type or function,
like `{type}`/`{func}`. If a `type`, returns only values
of this type. If a function, returns `func(obj)`.
- `str`/`int`: Return `obj[key]`. For `re.Match`, return `obj.group(key)`.
- `slice`: Branch out and return all values in `obj[key]`. - `slice`: Branch out and return all values in `obj[key]`.
- `Ellipsis`: Branch out and return a list of all values. - `Ellipsis`: Branch out and return a list of all values.
- `tuple`/`list`: Branch out and return a list of all matching values. - `tuple`/`list`: Branch out and return a list of all matching values.
@ -6016,6 +6057,9 @@ def traverse_obj(obj, *paths, **kwargs):
- `function`: Branch out and return values filtered by the function. - `function`: Branch out and return values filtered by the function.
Read as: `[value for key, value in obj if function(key, value)]`. Read as: `[value for key, value in obj if function(key, value)]`.
For `Sequence`s, `key` is the index of the value. For `Sequence`s, `key` is the index of the value.
For `Iterable`s, `key` is the enumeration count of the value.
For `re.Match`es, `key` is the group number (0 = full match)
as well as additionally any group names, if given.
- `dict` Transform the current object and return a matching dict. - `dict` Transform the current object and return a matching dict.
Read as: `{key: traverse_obj(obj, path) for key, path in dct.items()}`. Read as: `{key: traverse_obj(obj, path) for key, path in dct.items()}`.
@ -6024,8 +6068,12 @@ def traverse_obj(obj, *paths, **kwargs):
@params paths Paths which to traverse by. @params paths Paths which to traverse by.
Keyword arguments: Keyword arguments:
@param default Value to return if the paths do not match. @param default Value to return if the paths do not match.
If the last key in the path is a `dict`, it will apply to each value inside
the dict instead, depth first. Try to avoid if using nested `dict` keys.
@param expected_type If a `type`, only accept final values of this type. @param expected_type If a `type`, only accept final values of this type.
If any other callable, try to call the function on each result. If any other callable, try to call the function on each result.
If the last key in the path is a `dict`, it will apply to each value inside
the dict instead, recursively. This does respect branching paths.
@param get_all If `False`, return the first matching result, otherwise all matching ones. @param get_all If `False`, return the first matching result, otherwise all matching ones.
@param casesense If `False`, consider string dictionary keys as case insensitive. @param casesense If `False`, consider string dictionary keys as case insensitive.
@ -6036,12 +6084,15 @@ def traverse_obj(obj, *paths, **kwargs):
@param _traverse_string Whether to traverse into objects as strings. @param _traverse_string Whether to traverse into objects as strings.
If `True`, any non-compatible object will first be If `True`, any non-compatible object will first be
converted into a string and then traversed into. converted into a string and then traversed into.
The return value of that path will be a string instead,
not respecting any further branching.
@returns The result of the object traversal. @returns The result of the object traversal.
If successful, `get_all=True`, and the path branches at least once, If successful, `get_all=True`, and the path branches at least once,
then a list of results is returned instead. then a list of results is returned instead.
A list is always returned if the last path branches and no `default` is given. A list is always returned if the last path branches and no `default` is given.
If a path ends on a `dict` that result will always be a `dict`.
""" """
# parameter defaults # parameter defaults
@ -6055,7 +6106,6 @@ def traverse_obj(obj, *paths, **kwargs):
# instant compat # instant compat
str = compat_str str = compat_str
is_sequence = lambda x: isinstance(x, compat_collections_abc.Sequence) and not isinstance(x, (str, bytes))
casefold = lambda k: compat_casefold(k) if isinstance(k, str) else k casefold = lambda k: compat_casefold(k) if isinstance(k, str) else k
if isinstance(expected_type, type): if isinstance(expected_type, type):
@ -6063,128 +6113,184 @@ def traverse_obj(obj, *paths, **kwargs):
else: else:
type_test = lambda val: try_call(expected_type or IDENTITY, args=(val,)) type_test = lambda val: try_call(expected_type or IDENTITY, args=(val,))
def lookup_or_none(v, k, getter=None):
try:
return getter(v, k) if getter else v[k]
except IndexError:
return None
def from_iterable(iterables): def from_iterable(iterables):
# chain.from_iterable(['ABC', 'DEF']) --> A B C D E F # chain.from_iterable(['ABC', 'DEF']) --> A B C D E F
for it in iterables: for it in iterables:
for item in it: for item in it:
yield item yield item
def apply_key(key, obj): def apply_key(key, obj, is_last):
if obj is None: branching = False
return
if obj is None and _traverse_string:
if key is Ellipsis or callable(key) or isinstance(key, slice):
branching = True
result = ()
else:
result = None
elif key is None: elif key is None:
yield obj result = obj
elif isinstance(key, set):
assert len(key) == 1, 'Set should only be used to wrap a single item'
item = next(iter(key))
if isinstance(item, type):
result = obj if isinstance(obj, item) else None
else:
result = try_call(item, args=(obj,))
elif isinstance(key, (list, tuple)): elif isinstance(key, (list, tuple)):
for branch in key: branching = True
_, result = apply_path(obj, branch) result = from_iterable(
for item in result: apply_path(obj, branch, is_last)[0] for branch in key)
yield item
elif key is Ellipsis: elif key is Ellipsis:
result = [] branching = True
if isinstance(obj, compat_collections_abc.Mapping): if isinstance(obj, compat_collections_abc.Mapping):
result = obj.values() result = obj.values()
elif is_sequence(obj): elif is_iterable_like(obj):
result = obj result = obj
elif isinstance(obj, compat_re_Match): elif isinstance(obj, compat_re_Match):
result = obj.groups() result = obj.groups()
elif _traverse_string: elif _traverse_string:
branching = False
result = str(obj) result = str(obj)
for item in result: else:
yield item result = ()
elif callable(key): elif callable(key):
if is_sequence(obj): branching = True
iter_obj = enumerate(obj) if isinstance(obj, compat_collections_abc.Mapping):
elif isinstance(obj, compat_collections_abc.Mapping):
iter_obj = obj.items() iter_obj = obj.items()
elif is_iterable_like(obj):
iter_obj = enumerate(obj)
elif isinstance(obj, compat_re_Match): elif isinstance(obj, compat_re_Match):
iter_obj = enumerate(itertools.chain([obj.group()], obj.groups())) iter_obj = itertools.chain(
enumerate(itertools.chain((obj.group(),), obj.groups())),
obj.groupdict().items())
elif _traverse_string: elif _traverse_string:
branching = False
iter_obj = enumerate(str(obj)) iter_obj = enumerate(str(obj))
else: else:
return iter_obj = ()
for item in (v for k, v in iter_obj if try_call(key, args=(k, v))):
yield item result = (v for k, v in iter_obj if try_call(key, args=(k, v)))
if not branching: # string traversal
result = ''.join(result)
elif isinstance(key, dict): elif isinstance(key, dict):
iter_obj = ((k, _traverse_obj(obj, v)) for k, v in key.items()) iter_obj = ((k, _traverse_obj(obj, v, False, is_last)) for k, v in key.items())
yield dict((k, v if v is not None else default) for k, v in iter_obj result = dict((k, v if v is not None else default) for k, v in iter_obj
if v is not None or default is not NO_DEFAULT) if v is not None or default is not NO_DEFAULT) or None
elif isinstance(obj, compat_collections_abc.Mapping): elif isinstance(obj, compat_collections_abc.Mapping):
yield (obj.get(key) if casesense or (key in obj) result = (try_call(obj.get, args=(key,))
if casesense or try_call(obj.__contains__, args=(key,))
else next((v for k, v in obj.items() if casefold(k) == key), None)) else next((v for k, v in obj.items() if casefold(k) == key), None))
elif isinstance(obj, compat_re_Match): elif isinstance(obj, compat_re_Match):
result = None
if isinstance(key, int) or casesense: if isinstance(key, int) or casesense:
try: # Py 2.6 doesn't have methods in the Match class/type
yield obj.group(key) result = lookup_or_none(obj, key, getter=lambda _, k: obj.group(k))
return
except IndexError:
pass
if not isinstance(key, str):
return
yield next((v for k, v in obj.groupdict().items() if casefold(k) == key), None) elif isinstance(key, str):
result = next((v for k, v in obj.groupdict().items()
if casefold(k) == key), None)
else: else:
if _is_user_input: result = None
key = (int_or_none(key) if ':' not in key if isinstance(key, (int, slice)):
else slice(*map(int_or_none, key.split(':')))) if is_iterable_like(obj, compat_collections_abc.Sequence):
branching = isinstance(key, slice)
result = lookup_or_none(obj, key)
elif _traverse_string:
result = lookup_or_none(str(obj), key)
if not isinstance(key, (int, slice)): return branching, result if branching else (result,)
def lazy_last(iterable):
iterator = iter(iterable)
prev = next(iterator, NO_DEFAULT)
if prev is NO_DEFAULT:
return return
if not is_sequence(obj): for item in iterator:
if not _traverse_string: yield False, prev
return prev = item
obj = str(obj)
try: yield True, prev
yield obj[key]
except IndexError:
pass
def apply_path(start_obj, path): def apply_path(start_obj, path, test_type):
objs = (start_obj,) objs = (start_obj,)
has_branched = False has_branched = False
for key in variadic(path): key = None
if _is_user_input and key == ':': for last, key in lazy_last(variadic(path, (str, bytes, dict, set))):
if _is_user_input and isinstance(key, str):
if key == ':':
key = Ellipsis key = Ellipsis
elif ':' in key:
key = slice(*map(int_or_none, key.split(':')))
elif int_or_none(key) is not None:
key = int(key)
if not casesense and isinstance(key, str): if not casesense and isinstance(key, str):
key = compat_casefold(key) key = compat_casefold(key)
if key is Ellipsis or isinstance(key, (list, tuple)) or callable(key): if __debug__ and callable(key):
has_branched = True # Verify function signature
args = inspect.getargspec(key)
if len(args.args) != 2:
# crash differently in 2.6 !
inspect.getcallargs(key, None, None)
key_func = functools.partial(apply_key, key) new_objs = []
objs = from_iterable(map(key_func, objs)) for obj in objs:
branching, results = apply_key(key, obj, last)
has_branched |= branching
new_objs.append(results)
return has_branched, objs objs = from_iterable(new_objs)
def _traverse_obj(obj, path, use_list=True): if test_type and not isinstance(key, (dict, list, tuple)):
has_branched, results = apply_path(obj, path) objs = map(type_test, objs)
results = LazyList(x for x in map(type_test, results) if x is not None)
return objs, has_branched, isinstance(key, dict)
def _traverse_obj(obj, path, allow_empty, test_type):
results, has_branched, is_dict = apply_path(obj, path, test_type)
results = LazyList(x for x in results if x not in (None, {}))
if get_all and has_branched: if get_all and has_branched:
return results.exhaust() if results or use_list else None if results:
return results.exhaust()
if allow_empty:
return [] if default is NO_DEFAULT else default
return None
return results[0] if results else None return results[0] if results else {} if allow_empty and is_dict else None
for index, path in enumerate(paths, 1): for index, path in enumerate(paths, 1):
use_list = default is NO_DEFAULT and index == len(paths) result = _traverse_obj(obj, path, index == len(paths), True)
result = _traverse_obj(obj, path, use_list)
if result is not None: if result is not None:
return result return result
return None if default is NO_DEFAULT else default return None if default is NO_DEFAULT else default
def T(x):
""" For use in yt-dl instead of {type} or set((type,)) """
return set((x,))
def get_first(obj, keys, **kwargs): def get_first(obj, keys, **kwargs):
return traverse_obj(obj, (Ellipsis,) + tuple(variadic(keys)), get_all=False, **kwargs) return traverse_obj(obj, (Ellipsis,) + tuple(variadic(keys)), get_all=False, **kwargs)