Add ability to read IDs from files
This commit is contained in:
parent
b58eebb51f
commit
1a4ff07f78
@ -6,9 +6,9 @@ import sys
|
||||
import click
|
||||
|
||||
from bdfr.archiver import Archiver
|
||||
from bdfr.cloner import RedditCloner
|
||||
from bdfr.configuration import Configuration
|
||||
from bdfr.downloader import RedditDownloader
|
||||
from bdfr.cloner import RedditCloner
|
||||
|
||||
logger = logging.getLogger()
|
||||
|
||||
@ -17,6 +17,7 @@ _common_options = [
|
||||
click.option('--authenticate', is_flag=True, default=None),
|
||||
click.option('--config', type=str, default=None),
|
||||
click.option('--disable-module', multiple=True, default=None, type=str),
|
||||
click.option('--include-id-file', multiple=True, default=None),
|
||||
click.option('--log', type=str, default=None),
|
||||
click.option('--saved', is_flag=True, default=None),
|
||||
click.option('--search', default=None, type=str),
|
||||
@ -26,12 +27,12 @@ _common_options = [
|
||||
click.option('-L', '--limit', default=None, type=int),
|
||||
click.option('-l', '--link', multiple=True, default=None, type=str),
|
||||
click.option('-m', '--multireddit', multiple=True, default=None, type=str),
|
||||
click.option('-S', '--sort', type=click.Choice(('hot', 'top', 'new', 'controversial', 'rising', 'relevance')),
|
||||
default=None),
|
||||
click.option('-s', '--subreddit', multiple=True, default=None, type=str),
|
||||
click.option('-v', '--verbose', default=None, count=True),
|
||||
click.option('-u', '--user', type=str, multiple=True, default=None),
|
||||
click.option('-t', '--time', type=click.Choice(('all', 'hour', 'day', 'week', 'month', 'year')), default=None),
|
||||
click.option('-S', '--sort', type=click.Choice(('hot', 'top', 'new',
|
||||
'controversial', 'rising', 'relevance')), default=None),
|
||||
click.option('-u', '--user', type=str, multiple=True, default=None),
|
||||
click.option('-v', '--verbose', default=None, count=True),
|
||||
]
|
||||
|
||||
_downloader_options = [
|
||||
|
@ -18,6 +18,7 @@ class Configuration(Namespace):
|
||||
self.exclude_id_file = []
|
||||
self.file_scheme: str = '{REDDITOR}_{TITLE}_{POSTID}'
|
||||
self.folder_scheme: str = '{SUBREDDIT}'
|
||||
self.include_id_file = []
|
||||
self.limit: Optional[int] = None
|
||||
self.link: list[str] = []
|
||||
self.log: Optional[str] = None
|
||||
|
@ -3,6 +3,7 @@
|
||||
|
||||
import configparser
|
||||
import importlib.resources
|
||||
import itertools
|
||||
import logging
|
||||
import logging.handlers
|
||||
import re
|
||||
@ -78,7 +79,12 @@ class RedditConnector(metaclass=ABCMeta):
|
||||
self.create_reddit_instance()
|
||||
self.args.user = list(filter(None, [self.resolve_user_name(user) for user in self.args.user]))
|
||||
|
||||
self.excluded_submission_ids = self.read_excluded_ids()
|
||||
self.excluded_submission_ids = set.union(
|
||||
self.read_id_files(self.args.exclude_id_file),
|
||||
set(self.args.exclude_id),
|
||||
)
|
||||
|
||||
self.args.link = list(itertools.chain(self.args.link, self.read_id_files(self.args.include_id_file)))
|
||||
|
||||
self.master_hash_list = {}
|
||||
self.authenticator = self.create_authenticator()
|
||||
@ -403,13 +409,13 @@ class RedditConnector(metaclass=ABCMeta):
|
||||
except prawcore.Forbidden:
|
||||
raise errors.BulkDownloaderException(f'Source {subreddit.display_name} is private and cannot be scraped')
|
||||
|
||||
def read_excluded_ids(self) -> set[str]:
|
||||
@staticmethod
|
||||
def read_id_files(file_locations: list[str]) -> set[str]:
|
||||
out = []
|
||||
out.extend(self.args.exclude_id)
|
||||
for id_file in self.args.exclude_id_file:
|
||||
for id_file in file_locations:
|
||||
id_file = Path(id_file).resolve().expanduser()
|
||||
if not id_file.exists():
|
||||
logger.warning(f'ID exclusion file at {id_file} does not exist')
|
||||
logger.warning(f'ID file at {id_file} does not exist')
|
||||
continue
|
||||
with open(id_file, 'r') as file:
|
||||
for line in file:
|
||||
|
@ -306,3 +306,17 @@ def test_cli_download_disable_modules(test_args: list[str], tmp_path: Path):
|
||||
assert result.exit_code == 0
|
||||
assert 'skipped due to disabled module' in result.output
|
||||
assert 'Downloaded submission' not in result.output
|
||||
|
||||
|
||||
@pytest.mark.online
|
||||
@pytest.mark.reddit
|
||||
@pytest.mark.skipif(not does_test_config_exist, reason='A test config file is required for integration tests')
|
||||
def test_cli_download_include_id_file(tmp_path: Path):
|
||||
test_file = Path(tmp_path, 'include.txt')
|
||||
test_args = ['--include-id-file', str(test_file)]
|
||||
test_file.write_text('odr9wg\nody576')
|
||||
runner = CliRunner()
|
||||
test_args = create_basic_args_for_download_runner(test_args, tmp_path)
|
||||
result = runner.invoke(cli, test_args)
|
||||
assert result.exit_code == 0
|
||||
assert 'Downloaded submission' in result.output
|
||||
|
@ -339,11 +339,10 @@ def test_split_subreddit_entries(test_subreddit_entries: list[str], expected: se
|
||||
assert results == expected
|
||||
|
||||
|
||||
def test_read_excluded_submission_ids_from_file(downloader_mock: MagicMock, tmp_path: Path):
|
||||
def test_read_submission_ids_from_file(downloader_mock: MagicMock, tmp_path: Path):
|
||||
test_file = tmp_path / 'test.txt'
|
||||
test_file.write_text('aaaaaa\nbbbbbb')
|
||||
downloader_mock.args.exclude_id_file = [test_file]
|
||||
results = RedditConnector.read_excluded_ids(downloader_mock)
|
||||
results = RedditConnector.read_id_files([str(test_file)])
|
||||
assert results == {'aaaaaa', 'bbbbbb'}
|
||||
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user