Refactor method to remove max wait time

This commit is contained in:
Serene-Arc 2021-07-27 14:02:30 +10:00
parent 3cdae99490
commit dbe8733fd4
7 changed files with 10 additions and 7 deletions

View File

@@ -6,6 +6,7 @@ import logging
 import re
 import time
 import urllib.parse
+from collections import namedtuple
 from typing import Callable, Optional

 import _hashlib
@@ -29,7 +30,9 @@ class Resource:
         self.extension = self._determine_extension()

     @staticmethod
-    def retry_download(url: str, max_wait_time: int) -> Callable:
+    def retry_download(url: str) -> Callable:
+        max_wait_time = 300
         def http_download() -> Optional[bytes]:
             current_wait_time = 60
             while True:

View File

@@ -14,4 +14,4 @@ class Direct(BaseDownloader):
         super().__init__(post)

     def find_resources(self, authenticator: Optional[SiteAuthenticator] = None) -> list[Resource]:
-        return [Resource(self.post, self.post.url, Resource.retry_download(self.post.url, 300))]
+        return [Resource(self.post, self.post.url, Resource.retry_download(self.post.url))]

View File

@@ -29,7 +29,7 @@ class Erome(BaseDownloader):
         for link in links:
             if not re.match(r'https?://.*', link):
                 link = 'https://' + link
-            out.append(Resource(self.post, link, Resource.retry_download(link, 300)))
+            out.append(Resource(self.post, link, Resource.retry_download(link)))
         return out

     @staticmethod

View File

@@ -31,7 +31,7 @@ class Gallery(BaseDownloader):
         if not image_urls:
             raise SiteDownloaderError('No images found in Reddit gallery')
-        return [Resource(self.post, url, Resource.retry_download(url, 300)) for url in image_urls]
+        return [Resource(self.post, url, Resource.retry_download(url)) for url in image_urls]

     @ staticmethod
     def _get_links(id_dict: list[dict]) -> list[str]:

View File

@@ -33,7 +33,7 @@ class Imgur(BaseDownloader):
     def _compute_image_url(self, image: dict) -> Resource:
         image_url = 'https://i.imgur.com/' + image['hash'] + self._validate_extension(image['ext'])
-        return Resource(self.post, image_url, Resource.retry_download(image_url, 300))
+        return Resource(self.post, image_url, Resource.retry_download(image_url))

     @staticmethod
     def _get_data(link: str) -> dict:

View File

@@ -18,7 +18,7 @@ class Redgifs(BaseDownloader):
     def find_resources(self, authenticator: Optional[SiteAuthenticator] = None) -> list[Resource]:
         media_url = self._get_link(self.post.url)
-        return [Resource(self.post, media_url, Resource.retry_download(media_url, 300), '.mp4')]
+        return [Resource(self.post, media_url, Resource.retry_download(media_url), '.mp4')]

     @staticmethod
     def _get_link(url: str) -> str:

View File

@@ -31,6 +31,6 @@ def test_resource_get_extension(test_url: str, expected: str):
         ('https://www.iana.org/_img/2013.1/iana-logo-header.svg', '426b3ac01d3584c820f3b7f5985d6623'),
     ))
 def test_download_online_resource(test_url: str, expected_hash: str):
-    test_resource = Resource(MagicMock(), test_url, Resource.retry_download(test_url, 60))
+    test_resource = Resource(MagicMock(), test_url, Resource.retry_download(test_url))
     test_resource.download()
     assert test_resource.hash.hexdigest() == expected_hash