diff --git a/bdfr/__main__.py b/bdfr/__main__.py index 2823ce1..dadba51 100644 --- a/bdfr/__main__.py +++ b/bdfr/__main__.py @@ -111,9 +111,10 @@ def cli_download(context: click.Context, **_): """Used to download content posted to Reddit.""" config = Configuration() config.process_click_arguments(context) - setup_logging(config.verbose) + silence_module_loggers() + stream = make_console_logging_handler(config.verbose) try: - reddit_downloader = RedditDownloader(config) + reddit_downloader = RedditDownloader(config, [stream]) reddit_downloader.download() except Exception: logger.exception("Downloader exited unexpectedly") @@ -131,9 +132,10 @@ def cli_archive(context: click.Context, **_): """Used to archive post data from Reddit.""" config = Configuration() config.process_click_arguments(context) - setup_logging(config.verbose) + silence_module_loggers() + stream = make_console_logging_handler(config.verbose) try: - reddit_archiver = Archiver(config) + reddit_archiver = Archiver(config, [stream]) reddit_archiver.download() except Exception: logger.exception("Archiver exited unexpectedly") @@ -152,9 +154,10 @@ def cli_clone(context: click.Context, **_): """Combines archive and download commands.""" config = Configuration() config.process_click_arguments(context) - setup_logging(config.verbose) + silence_module_loggers() + stream = make_console_logging_handler(config.verbose) try: - reddit_scraper = RedditCloner(config) + reddit_scraper = RedditCloner(config, [stream]) reddit_scraper.download() except Exception: logger.exception("Scraper exited unexpectedly") @@ -187,7 +190,7 @@ def cli_completion(shell: str, uninstall: bool): Completion(shell).install() -def setup_logging(verbosity: int): +def make_console_logging_handler(verbosity: int) -> logging.StreamHandler: class StreamExceptionFilter(logging.Filter): def filter(self, record: logging.LogRecord) -> bool: result = not (record.levelno == logging.ERROR and record.exc_info) @@ -200,13 +203,16 @@ def setup_logging(verbosity: int): formatter = logging.Formatter("[%(asctime)s - %(name)s - %(levelname)s] - %(message)s") stream.setFormatter(formatter) - logger.addHandler(stream) if verbosity <= 0: stream.setLevel(logging.INFO) elif verbosity == 1: stream.setLevel(logging.DEBUG) else: stream.setLevel(9) + return stream + + +def silence_module_loggers(): logging.getLogger("praw").setLevel(logging.CRITICAL) logging.getLogger("prawcore").setLevel(logging.CRITICAL) logging.getLogger("urllib3").setLevel(logging.CRITICAL) diff --git a/bdfr/archiver.py b/bdfr/archiver.py index be5a445..88136a8 100644 --- a/bdfr/archiver.py +++ b/bdfr/archiver.py @@ -5,7 +5,7 @@ import json import logging import re from time import sleep -from typing import Iterator, Union +from typing import Iterable, Iterator, Union import dict2xml import praw.models @@ -24,8 +24,8 @@ logger = logging.getLogger(__name__) class Archiver(RedditConnector): - def __init__(self, args: Configuration): - super(Archiver, self).__init__(args) + def __init__(self, args: Configuration, logging_handlers: Iterable[logging.Handler] = ()): + super(Archiver, self).__init__(args, logging_handlers) def download(self): for generator in self.reddit_lists: diff --git a/bdfr/cloner.py b/bdfr/cloner.py index 53108c0..c31f6cc 100644 --- a/bdfr/cloner.py +++ b/bdfr/cloner.py @@ -3,6 +3,7 @@ import logging from time import sleep +from typing import Iterable import prawcore @@ -14,8 +15,8 @@ logger = logging.getLogger(__name__) class RedditCloner(RedditDownloader, Archiver): - def __init__(self, args: Configuration): - super(RedditCloner, self).__init__(args) + def __init__(self, args: Configuration, logging_handlers: Iterable[logging.Handler] = ()): + super(RedditCloner, self).__init__(args, logging_handlers) def download(self): for generator in self.reddit_lists: diff --git a/bdfr/connector.py b/bdfr/connector.py index 860750d..89339f0 100644 --- a/bdfr/connector.py +++ b/bdfr/connector.py @@ -14,7 +14,7 @@ from datetime import datetime from enum import Enum, auto from pathlib import Path from time import sleep -from typing import Callable, Iterator +from typing import Callable, Iterable, Iterator import appdirs import praw @@ -51,20 +51,20 @@ class RedditTypes: class RedditConnector(metaclass=ABCMeta): - def __init__(self, args: Configuration): + def __init__(self, args: Configuration, logging_handlers: Iterable[logging.Handler] = ()): self.args = args self.config_directories = appdirs.AppDirs("bdfr", "BDFR") + self.determine_directories() + self.load_config() + self.read_config() + file_log = self.create_file_logger() + self._apply_logging_handlers(itertools.chain(logging_handlers, [file_log])) self.run_time = datetime.now().isoformat() self._setup_internal_objects() self.reddit_lists = self.retrieve_reddit_lists() def _setup_internal_objects(self): - self.determine_directories() - self.load_config() - self.create_file_logger() - - self.read_config() self.parse_disabled_modules() @@ -94,6 +94,12 @@ class RedditConnector(metaclass=ABCMeta): self.args.skip_subreddit = self.split_args_input(self.args.skip_subreddit) self.args.skip_subreddit = {sub.lower() for sub in self.args.skip_subreddit} + @staticmethod + def _apply_logging_handlers(handlers: Iterable[logging.Handler]): + main_logger = logging.getLogger() + for handler in handlers: + main_logger.addHandler(handler) + def read_config(self): """Read any cfg values that need to be processed""" if self.args.max_wait_time is None: @@ -203,8 +209,7 @@ class RedditConnector(metaclass=ABCMeta): raise errors.BulkDownloaderException("Could not find a configuration file to load") self.cfg_parser.read(self.config_location) - def create_file_logger(self): - main_logger = logging.getLogger() + def create_file_logger(self) -> logging.handlers.RotatingFileHandler: if self.args.log is None: log_path = Path(self.config_directory, "log_output.txt") else: @@ -229,8 +234,7 @@ class RedditConnector(metaclass=ABCMeta): formatter = logging.Formatter("[%(asctime)s - %(name)s - %(levelname)s] - %(message)s") file_handler.setFormatter(formatter) file_handler.setLevel(0) - - main_logger.addHandler(file_handler) + return file_handler @staticmethod def sanitise_subreddit_name(subreddit: str) -> str: diff --git a/bdfr/downloader.py b/bdfr/downloader.py index 7ad8a6b..31c839d 100644 --- a/bdfr/downloader.py +++ b/bdfr/downloader.py @@ -9,6 +9,7 @@ from datetime import datetime from multiprocessing import Pool from pathlib import Path from time import sleep +from typing import Iterable import praw import praw.exceptions @@ -36,8 +37,8 @@ def _calc_hash(existing_file: Path): class RedditDownloader(RedditConnector): - def __init__(self, args: Configuration): - super(RedditDownloader, self).__init__(args) + def __init__(self, args: Configuration, logging_handlers: Iterable[logging.Handler] = ()): + super(RedditDownloader, self).__init__(args, logging_handlers) if self.args.search_existing: self.master_hash_list = self.scan_existing_files(self.download_directory) diff --git a/tests/integration_tests/test_download_integration.py b/tests/integration_tests/test_download_integration.py index 4dae353..5d7238a 100644 --- a/tests/integration_tests/test_download_integration.py +++ b/tests/integration_tests/test_download_integration.py @@ -52,9 +52,9 @@ def create_basic_args_for_download_runner(test_args: list[str], run_path: Path): ["-s", "trollxchromosomes", "-L", 3, "--sort", "new"], ["-s", "trollxchromosomes", "-L", 3, "--time", "day", "--sort", "new"], ["-s", "trollxchromosomes", "-L", 3, "--search", "women"], - ["-s", "trollxchromosomes", "-L", 3, "--time", "day", "--search", "women"], + ["-s", "trollxchromosomes", "-L", 3, "--time", "week", "--search", "women"], ["-s", "trollxchromosomes", "-L", 3, "--sort", "new", "--search", "women"], - ["-s", "trollxchromosomes", "-L", 3, "--time", "day", "--sort", "new", "--search", "women"], + ["-s", "trollxchromosomes", "-L", 3, "--time", "week", "--sort", "new", "--search", "women"], ), ) def test_cli_download_subreddits(test_args: list[str], tmp_path: Path): diff --git a/tests/test_connector.py b/tests/test_connector.py index 41c44e3..d300b45 100644 --- a/tests/test_connector.py +++ b/tests/test_connector.py @@ -254,7 +254,7 @@ def test_get_subreddit_time_verification( for r in results: result_time = datetime.fromtimestamp(r.created_utc) time_diff = nowtime - result_time - assert abs(time_diff - test_delta) < timedelta(minutes=1) + assert time_diff < (test_delta + timedelta(minutes=1)) @pytest.mark.online diff --git a/tests/test_downloader.py b/tests/test_downloader.py index d7aa8dd..b78a81c 100644 --- a/tests/test_downloader.py +++ b/tests/test_downloader.py @@ -1,6 +1,6 @@ #!/usr/bin/env python3 # -*- coding: utf-8 -*- - +import logging import os import re from pathlib import Path @@ -9,12 +9,16 @@ from unittest.mock import MagicMock, patch import praw.models import pytest -from bdfr.__main__ import setup_logging +from bdfr.__main__ import make_console_logging_handler from bdfr.configuration import Configuration from bdfr.connector import RedditConnector from bdfr.downloader import RedditDownloader +def add_console_handler(): + logging.getLogger().addHandler(make_console_logging_handler(3)) + + @pytest.fixture() def args() -> Configuration: args = Configuration() @@ -134,7 +138,7 @@ def test_download_submission_hash_exists( tmp_path: Path, capsys: pytest.CaptureFixture, ): - setup_logging(3) + add_console_handler() downloader_mock.reddit_instance = reddit_instance downloader_mock.download_filter.check_url.return_value = True downloader_mock.args.folder_scheme = "" @@ -155,7 +159,7 @@ def test_download_submission_hash_exists( def test_download_submission_file_exists( downloader_mock: MagicMock, reddit_instance: praw.Reddit, tmp_path: Path, capsys: pytest.CaptureFixture ): - setup_logging(3) + add_console_handler() downloader_mock.reddit_instance = reddit_instance downloader_mock.download_filter.check_url.return_value = True downloader_mock.args.folder_scheme = "" @@ -202,7 +206,7 @@ def test_download_submission_min_score_above( tmp_path: Path, capsys: pytest.CaptureFixture, ): - setup_logging(3) + add_console_handler() downloader_mock.reddit_instance = reddit_instance downloader_mock.download_filter.check_url.return_value = True downloader_mock.args.folder_scheme = "" @@ -226,7 +230,7 @@ def test_download_submission_min_score_below( tmp_path: Path, capsys: pytest.CaptureFixture, ): - setup_logging(3) + add_console_handler() downloader_mock.reddit_instance = reddit_instance downloader_mock.download_filter.check_url.return_value = True downloader_mock.args.folder_scheme = "" @@ -250,7 +254,7 @@ def test_download_submission_max_score_below( tmp_path: Path, capsys: pytest.CaptureFixture, ): - setup_logging(3) + add_console_handler() downloader_mock.reddit_instance = reddit_instance downloader_mock.download_filter.check_url.return_value = True downloader_mock.args.folder_scheme = "" @@ -274,7 +278,7 @@ def test_download_submission_max_score_above( tmp_path: Path, capsys: pytest.CaptureFixture, ): - setup_logging(3) + add_console_handler() downloader_mock.reddit_instance = reddit_instance downloader_mock.download_filter.check_url.return_value = True downloader_mock.args.folder_scheme = ""