# bulk-downloader-for-reddit/tests/test_connector.py
# (525 lines, 18 KiB, Python)

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
2021-12-09 11:04:11 +08:00
from datetime import datetime, timedelta, timezone
from pathlib import Path
from typing import Iterator, Optional
from unittest.mock import MagicMock

import praw
import praw.models
import pytest

from bdfr.configuration import Configuration
from bdfr.connector import RedditConnector, RedditTypes
from bdfr.download_filter import DownloadFilter
from bdfr.exceptions import BulkDownloaderException
from bdfr.file_name_formatter import FileNameFormatter
from bdfr.site_authenticator import SiteAuthenticator
@pytest.fixture()
def args() -> Configuration:
    """Provide a fresh ``Configuration`` with its time format set to ``"ISO"``."""
    configuration = Configuration()
    configuration.time_format = "ISO"
    return configuration
@pytest.fixture()
def downloader_mock(args: Configuration):
    """Build a ``MagicMock`` that tests pass as ``self`` to ``RedditConnector`` methods.

    The mock carries the ``args`` fixture and forwards a handful of helpers to
    the real ``RedditConnector`` implementations; everything else stays mocked.
    """

    def _filtered_listing(source):
        # Call the real generator factory with the mock standing in for ``self``.
        return RedditConnector.create_filtered_listing_generator(mock, source)

    mock = MagicMock()
    mock.args = args
    mock.sanitise_subreddit_name = RedditConnector.sanitise_subreddit_name
    mock.create_filtered_listing_generator = _filtered_listing
    mock.split_args_input = RedditConnector.split_args_input
    mock.master_hash_list = {}
    return mock
def assert_all_results_are_submissions(result_limit: Optional[int], results: list[Iterator]) -> list:
    """Flatten ``results`` and assert every item is a ``praw.models.Submission``.

    Args:
        result_limit: Exact number of items expected after flattening. The
            length check is skipped when this is ``None``.
        results: Iterables of results to flatten into a single list.

    Returns:
        The flattened list of items.

    Raises:
        AssertionError: If an item is not a submission, an item is a
            ``MagicMock``, or the length does not equal ``result_limit``.
    """
    flattened = [item for listing in results for item in listing]
    assert all(isinstance(item, praw.models.Submission) for item in flattened)
    # Guard against mock objects standing in for real results.
    assert not any(isinstance(item, MagicMock) for item in flattened)
    if result_limit is not None:
        assert len(flattened) == result_limit
    return flattened
2021-07-18 12:42:10 +08:00
def assert_all_results_are_submissions_or_comments(result_limit: Optional[int], results: list[Iterator]) -> list:
    """Flatten ``results`` and assert every item is a submission or a comment.

    Args:
        result_limit: Exact number of items expected after flattening. The
            length check is skipped when this is ``None``.
        results: Iterables of results to flatten into a single list.

    Returns:
        The flattened list of items.

    Raises:
        AssertionError: If an item is neither a ``praw.models.Submission`` nor
            a ``praw.models.Comment``, an item is a ``MagicMock``, or the
            length does not equal ``result_limit``.
    """
    flattened = [item for listing in results for item in listing]
    assert all(isinstance(item, (praw.models.Submission, praw.models.Comment)) for item in flattened)
    # Guard against mock objects standing in for real results.
    assert not any(isinstance(item, MagicMock) for item in flattened)
    if result_limit is not None:
        assert len(flattened) == result_limit
    return flattened
def test_determine_directories(tmp_path: Path, downloader_mock: MagicMock):
    """After ``determine_directories`` runs, the configured download directory exists."""
    target_directory = tmp_path / "test"
    downloader_mock.args.directory = target_directory
    downloader_mock.config_directories.user_config_dir = tmp_path
    RedditConnector.determine_directories(downloader_mock)
    assert target_directory.exists()
@pytest.mark.parametrize(
    ("skip_extensions", "skip_domains"),
    (
        ([], []),
        ([".test"], ["test.com"]),
    ),
)
def test_create_download_filter(skip_extensions: list[str], skip_domains: list[str], downloader_mock: MagicMock):
    """The created ``DownloadFilter`` mirrors the ``skip`` and ``skip_domain`` arguments."""
    cli_args = downloader_mock.args
    cli_args.skip = skip_extensions
    cli_args.skip_domain = skip_domains

    download_filter = RedditConnector.create_download_filter(downloader_mock)

    assert isinstance(download_filter, DownloadFilter)
    assert download_filter.excluded_domains == skip_domains
    assert download_filter.excluded_extensions == skip_extensions
2022-12-03 13:11:17 +08:00
@pytest.mark.parametrize(
    ("test_time", "expected"),
    (
        ("all", "all"),
        ("hour", "hour"),
        ("day", "day"),
        ("week", "week"),
        ("random", "all"),
        ("", "all"),
    ),
)
def test_create_time_filter(test_time: str, expected: str, downloader_mock: MagicMock):
    """``args.time`` maps to the matching ``TimeType``; unknown or empty values yield ``all``."""
    downloader_mock.args.time = test_time

    time_filter = RedditConnector.create_time_filter(downloader_mock)

    assert isinstance(time_filter, RedditTypes.TimeType)
    assert time_filter.name.lower() == expected
2022-12-03 13:11:17 +08:00
@pytest.mark.parametrize(
    ("test_sort", "expected"),
    (
        ("", "hot"),
        ("hot", "hot"),
        ("controversial", "controversial"),
        ("new", "new"),
    ),
)
def test_create_sort_filter(test_sort: str, expected: str, downloader_mock: MagicMock):
    """``args.sort`` maps to the matching ``SortType``; an empty value yields ``hot``."""
    downloader_mock.args.sort = test_sort

    sort_filter = RedditConnector.create_sort_filter(downloader_mock)

    assert isinstance(sort_filter, RedditTypes.SortType)
    assert sort_filter.name.lower() == expected
2022-12-03 13:11:17 +08:00
@pytest.mark.parametrize(
    ("test_file_scheme", "test_folder_scheme"),
    (
        ("{POSTID}", "{SUBREDDIT}"),
        ("{REDDITOR}_{TITLE}_{POSTID}", "{SUBREDDIT}"),
        ("{POSTID}", "test"),
        ("{POSTID}", ""),
        ("{POSTID}", "{SUBREDDIT}/{REDDITOR}"),
    ),
)
def test_create_file_name_formatter(test_file_scheme: str, test_folder_scheme: str, downloader_mock: MagicMock):
    """Valid schemes produce a ``FileNameFormatter`` holding the file scheme and the split folder scheme."""
    cli_args = downloader_mock.args
    cli_args.file_scheme = test_file_scheme
    cli_args.folder_scheme = test_folder_scheme

    formatter = RedditConnector.create_file_name_formatter(downloader_mock)

    assert isinstance(formatter, FileNameFormatter)
    assert formatter.file_format_string == test_file_scheme
    # The folder scheme is compared as its "/"-separated components.
    expected_directories = test_folder_scheme.split("/")
    assert formatter.directory_format_string == expected_directories
2022-12-03 13:11:17 +08:00
@pytest.mark.parametrize(
    ("test_file_scheme", "test_folder_scheme"),
    (
        ("", ""),
        ("", "{SUBREDDIT}"),
        ("test", "{SUBREDDIT}"),
    ),
)
def test_create_file_name_formatter_bad(test_file_scheme: str, test_folder_scheme: str, downloader_mock: MagicMock):
    """The listed file schemes make ``create_file_name_formatter`` raise ``BulkDownloaderException``."""
    cli_args = downloader_mock.args
    cli_args.file_scheme = test_file_scheme
    cli_args.folder_scheme = test_folder_scheme

    with pytest.raises(BulkDownloaderException):
        RedditConnector.create_file_name_formatter(downloader_mock)
def test_create_authenticator(downloader_mock: MagicMock):
    """``create_authenticator`` returns a ``SiteAuthenticator`` instance."""
    authenticator = RedditConnector.create_authenticator(downloader_mock)
    assert isinstance(authenticator, SiteAuthenticator)
@pytest.mark.online
@pytest.mark.reddit
@pytest.mark.parametrize(
    "test_submission_ids",
    (
        ("lvpf4l",),
        ("lvpf4l", "lvqnsn"),
        ("lvpf4l", "lvqnsn", "lvl9kd"),
        ("1000000",),
    ),
)
def test_get_submissions_from_link(
    test_submission_ids: list[str], reddit_instance: praw.Reddit, downloader_mock: MagicMock
):
    """Each ID in ``args.link`` yields one ``Submission`` in the first result group."""
    downloader_mock.args.link = test_submission_ids
    downloader_mock.reddit_instance = reddit_instance

    submission_groups = RedditConnector.get_submissions_from_link(downloader_mock)

    for group in submission_groups:
        for submission in group:
            assert isinstance(submission, praw.models.Submission)
    assert len(submission_groups[0]) == len(test_submission_ids)
@pytest.mark.online
@pytest.mark.reddit
@pytest.mark.parametrize(
    ("test_subreddits", "limit", "sort_type", "time_filter", "max_expected_len"),
    (
        (("Futurology",), 10, "hot", "all", 10),
        (("Futurology", "Mindustry, Python"), 10, "hot", "all", 30),
        (("Futurology",), 20, "hot", "all", 20),
        (("Futurology", "Python"), 10, "hot", "all", 20),
        (("Futurology",), 100, "hot", "all", 100),
        (("Futurology",), 0, "hot", "all", 0),
        (("Futurology",), 10, "top", "all", 10),
        (("Futurology",), 10, "top", "week", 10),
        (("Futurology",), 10, "hot", "week", 10),
    ),
)
def test_get_subreddit_normal(
    test_subreddits: list[str],
    limit: int,
    sort_type: str,
    time_filter: str,
    max_expected_len: int,
    downloader_mock: MagicMock,
    reddit_instance: praw.Reddit,
):
    """Fetch listings from one or more subreddits and check type, origin and count.

    Args:
        test_subreddits: Subreddit entries as given on the command line; an
            entry may hold several comma-separated names.
        limit: Value assigned to ``args.limit``.
        sort_type: Value assigned to ``args.sort``.
        time_filter: Value assigned to ``args.time``.
        max_expected_len: Upper bound on the total number of results.
    """
    downloader_mock.args.limit = limit
    downloader_mock.args.sort = sort_type
    # The parametrized time filter must be applied before the filter object is
    # built; otherwise the "week" cases silently run with the default value.
    downloader_mock.args.time = time_filter
    downloader_mock.time_filter = RedditConnector.create_time_filter(downloader_mock)
    downloader_mock.sort_filter = RedditConnector.create_sort_filter(downloader_mock)
    downloader_mock.determine_sort_function.return_value = RedditConnector.determine_sort_function(downloader_mock)
    downloader_mock.args.subreddit = test_subreddits
    downloader_mock.reddit_instance = reddit_instance

    results = RedditConnector.get_subreddits(downloader_mock)

    # Normalise the entries the same way the connector does, so that
    # comma-separated entries compare against individual subreddit names.
    test_subreddits = downloader_mock.split_args_input(test_subreddits)
    results = [submission for listing in results for submission in listing]
    assert all(isinstance(submission, praw.models.Submission) for submission in results)
    assert all(submission.subreddit.display_name in test_subreddits for submission in results)
    assert len(results) <= max_expected_len
    assert not any(isinstance(submission, MagicMock) for submission in results)
2021-12-09 11:04:11 +08:00
@pytest.mark.online
@pytest.mark.reddit
@pytest.mark.parametrize(
    ("test_time", "test_delta"),
    (
        ("hour", timedelta(hours=1)),
        ("day", timedelta(days=1)),
        ("week", timedelta(days=7)),
        ("month", timedelta(days=31)),
        ("year", timedelta(days=365)),
    ),
)
def test_get_subreddit_time_verification(
    test_time: str,
    test_delta: timedelta,
    downloader_mock: MagicMock,
    reddit_instance: praw.Reddit,
):
    """Top submissions from r/all for a time filter are all younger than that window.

    Args:
        test_time: Value assigned to ``args.time``.
        test_delta: Maximum allowed age of any returned submission.
    """
    downloader_mock.args.limit = 10
    downloader_mock.args.sort = "top"
    downloader_mock.args.time = test_time
    downloader_mock.time_filter = RedditConnector.create_time_filter(downloader_mock)
    downloader_mock.sort_filter = RedditConnector.create_sort_filter(downloader_mock)
    downloader_mock.determine_sort_function.return_value = RedditConnector.determine_sort_function(downloader_mock)
    downloader_mock.args.subreddit = ["all"]
    downloader_mock.reddit_instance = reddit_instance

    results = RedditConnector.get_subreddits(downloader_mock)
    results = [submission for listing in results for submission in listing]
    assert all(isinstance(submission, praw.models.Submission) for submission in results)

    # Compare timezone-aware UTC datetimes: naive local-time arithmetic is off
    # by an hour across a DST change, which matters for the one-hour window.
    now_utc = datetime.now(tz=timezone.utc)
    for submission in results:
        created_at = datetime.fromtimestamp(submission.created_utc, tz=timezone.utc)
        assert now_utc - created_at < test_delta
@pytest.mark.online
@pytest.mark.reddit
@pytest.mark.parametrize(
    ("test_subreddits", "search_term", "limit", "time_filter", "max_expected_len"),
    (
        (("Python",), "scraper", 10, "all", 10),
        (("Python",), "", 10, "all", 0),
        (("Python",), "djsdsgewef", 10, "all", 0),
        (("Python",), "scraper", 10, "year", 10),
    ),
)
def test_get_subreddit_search(
    test_subreddits: list[str],
    search_term: str,
    time_filter: str,
    limit: int,
    max_expected_len: int,
    downloader_mock: MagicMock,
    reddit_instance: praw.Reddit,
):
    """Search subreddits for a term and check result type, origin and count.

    Cases with a non-zero ``max_expected_len`` must return at least one
    result; every case must return at most ``max_expected_len`` results.
    """
    # NOTE(review): this attribute has a leading underscore, unlike the
    # ``determine_sort_function`` used by the other tests in this file.
    # Confirm whether that is intentional before renaming it.
    downloader_mock._determine_sort_function.return_value = praw.models.Subreddit.hot
    cli_args = downloader_mock.args
    cli_args.limit = limit
    cli_args.search = search_term
    cli_args.subreddit = test_subreddits
    downloader_mock.reddit_instance = reddit_instance
    downloader_mock.sort_filter = RedditTypes.SortType.HOT
    cli_args.time = time_filter
    downloader_mock.time_filter = RedditConnector.create_time_filter(downloader_mock)

    listings = RedditConnector.get_subreddits(downloader_mock)
    found = [submission for listing in listings for submission in listing]

    for submission in found:
        assert isinstance(submission, praw.models.Submission)
    for submission in found:
        assert submission.subreddit.display_name in test_subreddits
    assert len(found) <= max_expected_len
    if max_expected_len != 0:
        assert found
    for submission in found:
        assert not isinstance(submission, MagicMock)
@pytest.mark.online
@pytest.mark.reddit
@pytest.mark.parametrize(
    ("test_user", "test_multireddits", "limit"),
    (
        ("helen_darten", ("cuteanimalpics",), 10),
        ("korfor", ("chess",), 100),
    ),
)
# Good sources at https://www.reddit.com/r/multihub/
def test_get_multireddits_public(
    test_user: str,
    test_multireddits: list[str],
    limit: int,
    reddit_instance: praw.Reddit,
    downloader_mock: MagicMock,
):
    """A public multireddit of ``test_user`` yields exactly ``limit`` real submissions."""
    downloader_mock.determine_sort_function.return_value = praw.models.Subreddit.hot
    downloader_mock.sort_filter = RedditTypes.SortType.HOT
    cli_args = downloader_mock.args
    cli_args.limit = limit
    cli_args.multireddit = test_multireddits
    cli_args.user = [test_user]
    downloader_mock.reddit_instance = reddit_instance
    # NOTE(review): the ``downloader_mock`` fixture installs a plain lambda as
    # ``create_filtered_listing_generator``, so this ``return_value`` looks
    # like it is never consulted. Confirm before removing.
    first_multireddit = reddit_instance.multireddit(redditor=test_user, name=test_multireddits[0])
    downloader_mock.create_filtered_listing_generator.return_value = (
        RedditConnector.create_filtered_listing_generator(downloader_mock, first_multireddit)
    )

    listings = RedditConnector.get_multireddits(downloader_mock)
    submissions = [submission for listing in listings for submission in listing]

    for submission in submissions:
        assert isinstance(submission, praw.models.Submission)
    assert len(submissions) == limit
    for submission in submissions:
        assert not isinstance(submission, MagicMock)
@pytest.mark.online
@pytest.mark.reddit
@pytest.mark.parametrize(
    ("test_user", "limit"),
    (
        ("danigirl3694", 10),
        ("danigirl3694", 50),
        ("CapitanHam", None),
    ),
)
def test_get_user_submissions(test_user: str, limit: int, downloader_mock: MagicMock, reddit_instance: praw.Reddit):
    """A user's submitted posts are real submissions authored by that user.

    When ``limit`` is not ``None`` the number of results must equal it.
    """
    cli_args = downloader_mock.args
    cli_args.limit = limit
    downloader_mock.determine_sort_function.return_value = praw.models.Subreddit.hot
    downloader_mock.sort_filter = RedditTypes.SortType.HOT
    cli_args.submitted = True
    cli_args.user = [test_user]
    downloader_mock.authenticated = False
    downloader_mock.reddit_instance = reddit_instance
    # NOTE(review): the ``downloader_mock`` fixture installs a plain lambda as
    # ``create_filtered_listing_generator``, so this ``return_value`` looks
    # like it is never consulted. Confirm before removing.
    user_submissions = reddit_instance.redditor(test_user).submissions
    downloader_mock.create_filtered_listing_generator.return_value = (
        RedditConnector.create_filtered_listing_generator(downloader_mock, user_submissions)
    )

    listings = RedditConnector.get_user_data(downloader_mock)
    submissions = assert_all_results_are_submissions(limit, listings)

    for submission in submissions:
        assert submission.author.name == test_user
    for submission in submissions:
        assert not isinstance(submission, MagicMock)
@pytest.mark.online
@pytest.mark.reddit
@pytest.mark.authenticated
@pytest.mark.parametrize(
    "test_flag",
    (
        "upvoted",
        "saved",
    ),
)
def test_get_user_authenticated_lists(
    test_flag: str,
    downloader_mock: MagicMock,
    authenticated_reddit_instance: praw.Reddit,
):
    """With ``test_flag`` enabled, the authenticated user's list yields 10 submissions or comments."""
    cli_args = downloader_mock.args
    # Enable the option named by ``test_flag`` directly in the namespace dict.
    vars(cli_args)[test_flag] = True
    downloader_mock.reddit_instance = authenticated_reddit_instance
    cli_args.limit = 10
    downloader_mock.determine_sort_function.return_value = praw.models.Subreddit.hot
    downloader_mock.sort_filter = RedditTypes.SortType.HOT
    resolved_name = RedditConnector.resolve_user_name(downloader_mock, "me")
    cli_args.user = [resolved_name]

    listings = RedditConnector.get_user_data(downloader_mock)

    assert_all_results_are_submissions_or_comments(10, listings)
2022-02-18 08:21:52 +08:00
@pytest.mark.online
@pytest.mark.reddit
@pytest.mark.authenticated
def test_get_subscribed_subreddits(downloader_mock: MagicMock, authenticated_reddit_instance: praw.Reddit):
    """With ``args.subscribed`` set, ``get_subreddits`` returns a non-empty list of listing generators."""
    downloader_mock.reddit_instance = authenticated_reddit_instance
    downloader_mock.args.limit = 10
    downloader_mock.args.authenticate = True
    downloader_mock.args.subscribed = True
    # This assignment used to appear twice in a row; once is enough.
    downloader_mock.determine_sort_function.return_value = praw.models.Subreddit.hot
    downloader_mock.sort_filter = RedditTypes.SortType.HOT

    results = RedditConnector.get_subreddits(downloader_mock)

    assert all(isinstance(listing, praw.models.ListingGenerator) for listing in results)
    # ``all`` is vacuously true for an empty list, so also require a result.
    assert results
2022-02-18 08:21:52 +08:00
2022-12-03 13:11:17 +08:00
@pytest.mark.parametrize(
    ("test_name", "expected"),
    (
        ("Mindustry", "Mindustry"),
        ("Futurology", "Futurology"),
        ("r/Mindustry", "Mindustry"),
        ("TrollXChromosomes", "TrollXChromosomes"),
        ("r/TrollXChromosomes", "TrollXChromosomes"),
        ("https://www.reddit.com/r/TrollXChromosomes/", "TrollXChromosomes"),
        ("https://www.reddit.com/r/TrollXChromosomes", "TrollXChromosomes"),
        ("https://www.reddit.com/r/Futurology/", "Futurology"),
        ("https://www.reddit.com/r/Futurology", "Futurology"),
    ),
)
def test_sanitise_subreddit_name(test_name: str, expected: str):
    """Bare names, ``r/`` prefixed names and full URLs all reduce to the bare subreddit name."""
    sanitised = RedditConnector.sanitise_subreddit_name(test_name)
    assert sanitised == expected
2022-12-03 13:11:17 +08:00
@pytest.mark.parametrize(
    ("test_subreddit_entries", "expected"),
    (
        (["test1", "test2", "test3"], {"test1", "test2", "test3"}),
        (["test1,test2", "test3"], {"test1", "test2", "test3"}),
        (["test1, test2", "test3"], {"test1", "test2", "test3"}),
        (["test1; test2", "test3"], {"test1", "test2", "test3"}),
        (["test1, test2", "test1,test2,test3", "test4"], {"test1", "test2", "test3", "test4"}),
        ([""], {""}),
        (["test"], {"test"}),
    ),
)
def test_split_subreddit_entries(test_subreddit_entries: list[str], expected: set[str]):
    """``split_args_input`` splits entries on ``,`` or ``;`` and returns the unique names as a set."""
    split_entries = RedditConnector.split_args_input(test_subreddit_entries)
    assert split_entries == expected
2021-07-05 14:58:33 +08:00
def test_read_submission_ids_from_file(downloader_mock: MagicMock, tmp_path: Path):
    """``read_id_files`` returns the set of IDs found one per line in the given file."""
    id_file = tmp_path / "test.txt"
    id_file.write_text("aaaaaa\nbbbbbb")

    file_paths = [str(id_file)]
    submission_ids = RedditConnector.read_id_files(file_paths)

    assert submission_ids == {"aaaaaa", "bbbbbb"}
@pytest.mark.online
@pytest.mark.reddit
@pytest.mark.parametrize(
    "test_redditor_name",
    (
        "nasa",
        "crowdstrike",
        "HannibalGoddamnit",
    ),
)
def test_check_user_existence_good(
    test_redditor_name: str,
    reddit_instance: praw.Reddit,
    downloader_mock: MagicMock,
):
    """``check_user_existence`` completes without raising for the listed redditors."""
    downloader_mock.reddit_instance = reddit_instance
    # The test passes as long as this call does not raise.
    RedditConnector.check_user_existence(downloader_mock, test_redditor_name)
@pytest.mark.online
@pytest.mark.reddit
@pytest.mark.parametrize(
    "test_redditor_name",
    (
        "lhnhfkuhwreolo",
        "adlkfmnhglojh",
    ),
)
def test_check_user_existence_nonexistent(
    test_redditor_name: str,
    reddit_instance: praw.Reddit,
    downloader_mock: MagicMock,
):
    """For the listed names, ``check_user_existence`` raises with a "Could not find" message."""
    downloader_mock.reddit_instance = reddit_instance
    expected_error = pytest.raises(BulkDownloaderException, match="Could not find")
    with expected_error:
        RedditConnector.check_user_existence(downloader_mock, test_redditor_name)
@pytest.mark.online
@pytest.mark.reddit
@pytest.mark.parametrize("test_redditor_name", ("Bree-Boo",))
def test_check_user_existence_banned(
    test_redditor_name: str,
    reddit_instance: praw.Reddit,
    downloader_mock: MagicMock,
):
    """For the listed name, ``check_user_existence`` raises with an "is banned" message."""
    downloader_mock.reddit_instance = reddit_instance
    expected_error = pytest.raises(BulkDownloaderException, match="is banned")
    with expected_error:
        RedditConnector.check_user_existence(downloader_mock, test_redditor_name)
@pytest.mark.online
@pytest.mark.reddit
@pytest.mark.parametrize(
    ("test_subreddit_name", "expected_message"),
    (
        ("donaldtrump", "cannot be found"),
        ("submitters", "private and cannot be scraped"),
        ("lhnhfkuhwreolo", "does not exist"),
    ),
)
def test_check_subreddit_status_bad(test_subreddit_name: str, expected_message: str, reddit_instance: praw.Reddit):
    """``check_subreddit_status`` raises with the expected message for each listed subreddit."""
    subreddit = reddit_instance.subreddit(test_subreddit_name)
    expected_error = pytest.raises(BulkDownloaderException, match=expected_message)
    with expected_error:
        RedditConnector.check_subreddit_status(subreddit)
@pytest.mark.online
@pytest.mark.reddit
@pytest.mark.parametrize(
    "test_subreddit_name",
    (
        "Python",
        "Mindustry",
        "TrollXChromosomes",
        "all",
    ),
)
def test_check_subreddit_status_good(test_subreddit_name: str, reddit_instance: praw.Reddit):
    """``check_subreddit_status`` completes without raising for the listed subreddits."""
    subreddit = reddit_instance.subreddit(test_subreddit_name)
    # The test passes as long as this call does not raise.
    RedditConnector.check_subreddit_status(subreddit)