Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion beets/autotag/hooks.py
Original file line number Diff line number Diff line change
Expand Up @@ -474,7 +474,7 @@ def _eq(self, value1: re.Pattern[str] | Any, value2: Any) -> bool:
matched against `value2`.
"""
if isinstance(value1, re.Pattern):
value2 = cast(str, value2)
value2 = cast("str", value2)
return bool(value1.match(value2))
return value1 == value2

Expand Down
15 changes: 8 additions & 7 deletions beets/autotag/match.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@

import datetime
import re
from collections.abc import Iterable, Sequence
from enum import IntEnum
from functools import cache
from typing import TYPE_CHECKING, Any, NamedTuple, TypeVar, cast
Expand All @@ -40,6 +39,8 @@
from beets.util import plurality

if TYPE_CHECKING:
from collections.abc import Iterable, Sequence

from beets.library import Item

# Artist signals that indicate "various artists". These are used at the
Expand Down Expand Up @@ -245,7 +246,7 @@ def distance(
if album_info.media:
# Preferred media options.
patterns = config["match"]["preferred"]["media"].as_str_seq()
patterns = cast(Sequence[str], patterns)
patterns = cast("Sequence[str]", patterns)
options = [re.compile(r"(\d+x)?(%s)" % pat, re.I) for pat in patterns]
if options:
dist.add_priority("media", album_info.media, options)
Expand Down Expand Up @@ -283,7 +284,7 @@ def distance(

# Preferred countries.
patterns = config["match"]["preferred"]["countries"].as_str_seq()
patterns = cast(Sequence[str], patterns)
patterns = cast("Sequence[str]", patterns)
options = [re.compile(pat, re.I) for pat in patterns]
if album_info.country and options:
dist.add_priority("country", album_info.country, options)
Expand Down Expand Up @@ -448,7 +449,7 @@ def _add_candidate(

# Discard matches without required tags.
for req_tag in cast(
Sequence[str], config["match"]["required"].as_str_seq()
"Sequence[str]", config["match"]["required"].as_str_seq()
):
if getattr(info, req_tag) is None:
log.debug("Ignored. Missing required tag: {0}", req_tag)
Expand All @@ -462,7 +463,7 @@ def _add_candidate(

# Skip matches with ignored penalties.
penalties = [key for key, _ in dist]
ignored = cast(Sequence[str], config["match"]["ignored"].as_str_seq())
ignored = cast("Sequence[str]", config["match"]["ignored"].as_str_seq())
for penalty in ignored:
if penalty in penalties:
log.debug("Ignored. Penalty: {0}", penalty)
Expand Down Expand Up @@ -499,8 +500,8 @@ def tag_album(
"""
# Get current metadata.
likelies, consensus = current_metadata(items)
cur_artist = cast(str, likelies["artist"])
cur_album = cast(str, likelies["album"])
cur_artist = cast("str", likelies["artist"])
cur_album = cast("str", likelies["album"])
log.debug("Tagging {0} - {1}", cur_artist, cur_album)

# The output result, keys are the MB album ID.
Expand Down
13 changes: 9 additions & 4 deletions beets/autotag/mb.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,8 @@
import re
import traceback
from collections import Counter
from collections.abc import Iterator, Sequence
from itertools import product
from typing import Any, cast
from typing import TYPE_CHECKING, Any, cast
from urllib.parse import urljoin

import musicbrainzngs
Expand All @@ -37,6 +36,9 @@
spotify_id_regex,
)

if TYPE_CHECKING:
from collections.abc import Iterator, Sequence

VARIOUS_ARTISTS_ID = "89ad4ac3-39f7-470e-963a-56509c546377"

BASE_URL = "https://musicbrainz.org/"
Expand Down Expand Up @@ -184,7 +186,7 @@ def _preferred_release_event(release: dict[str, Any]) -> tuple[str, str]:
default release event if a preferred event is not found.
"""
countries = config["match"]["preferred"]["countries"].as_str_seq()
countries = cast(Sequence, countries)
countries = cast("Sequence", countries)

for country in countries:
for event in release.get("release-event-list", {}):
Expand All @@ -194,7 +196,10 @@ def _preferred_release_event(release: dict[str, Any]) -> tuple[str, str]:
except KeyError:
pass

return (cast(str, release.get("country")), cast(str, release.get("date")))
return (
cast("str", release.get("country")),
cast("str", release.get("date")),
)


def _multi_artist_credit(
Expand Down
4 changes: 2 additions & 2 deletions beets/dbcore/db.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,8 +126,8 @@ def _get_formatted(self, model: Model, key: str) -> str:
value = value.decode("utf-8", "ignore")

if self.for_path:
sep_repl = cast(str, beets.config["path_sep_replace"].as_str())
sep_drive = cast(str, beets.config["drive_sep_replace"].as_str())
sep_repl = cast("str", beets.config["path_sep_replace"].as_str())
sep_drive = cast("str", beets.config["drive_sep_replace"].as_str())

if re.match(r"^\w:", value):
value = re.sub(r"(?<=^\w):", sep_drive, value)
Expand Down
4 changes: 2 additions & 2 deletions beets/dbcore/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ def null(self) -> N:
# Note that this default implementation only makes sense for T = N.
# It would be better to implement `null()` only in subclasses, or
# have a field null_type similar to `model_type` and use that here.
return cast(N, self.model_type())
return cast("N", self.model_type())

def format(self, value: N | T) -> str:
"""Given a value of this type, produce a Unicode string
Expand Down Expand Up @@ -115,7 +115,7 @@ def normalize(self, value: Any) -> T | N:
else:
# TODO This should eventually be replaced by
# `self.model_type(value)`
return cast(T, value)
return cast("T", value)

def from_sql(self, sql_value: SQLiteType) -> T | N:
"""Receives the value stored in the SQL backend and return the
Expand Down
47 changes: 41 additions & 6 deletions beets/importer.py
Original file line number Diff line number Diff line change
Expand Up @@ -248,6 +248,35 @@ def _setup_logging(self, loghandler: logging.Handler | None):
logger.handlers = [loghandler]
return logger

def tasks_created(self, _: list[ImportTask]) -> None:
"""Called when a list of tasks is created.

Expected to be called when an individual directory or query result is
transformed into a list of tasks.
"""
raise NotImplementedError

def task_candidates_found(self) -> None:
"""Called when a task has found candidates.

Expected to be called by an ImportTask when it has found candidates.
"""
raise NotImplementedError

def task_match_chosen(self) -> None:
"""Called when a task has chosen a match.

Expected to be called by an ImportTask when it has chosen a match.
"""
raise NotImplementedError

def task_finalized(self) -> None:
"""Called when a task has been finalized.

Expected to be called by an ImportTask when it has been finalized.
"""
raise NotImplementedError

def set_config(self, config):
"""Set `config` property from global import config and make
implied changes.
Expand Down Expand Up @@ -610,7 +639,7 @@ def chosen_info(self):
return likelies
elif self.choice_flag is action.APPLY and self.match:
return self.match.info.copy()
assert False
raise ValueError("Invalid choice flag; this should never happen.")

def imported_items(self):
"""Return a list of Items that should be added to the library.
Expand All @@ -625,7 +654,7 @@ def imported_items(self):
):
return list(self.match.mapping.keys())
else:
assert False
raise ValueError("Invalid choice flag; this should never happen.")

def apply_metadata(self):
"""Copy metadata from match info to the items."""
Expand Down Expand Up @@ -694,6 +723,8 @@ def finalize(self, session: ImportSession):
if not self.skip:
self._emit_imported(session.lib)

session.task_finalized()

def cleanup(self, copy=False, delete=False, move=False):
"""Remove and prune imported paths."""
# Do not delete any files or prune directories when skipping.
Expand Down Expand Up @@ -731,9 +762,10 @@ def handle_created(self, session: ImportSession):
else:
# The plugins gave us a list of lists of tasks. Flatten it.
tasks = [t for inner in tasks for t in inner]
session.tasks_created(tasks)
return tasks

def lookup_candidates(self):
def lookup_candidates(self, session: ImportSession):
"""Retrieve and store candidates for this album. User-specified
candidate IDs are stored in self.search_ids: if present, the
initial lookup is restricted to only those IDs.
Expand All @@ -745,6 +777,7 @@ def lookup_candidates(self):
self.cur_album = album
self.candidates = prop.candidates
self.rec = prop.recommendation
session.task_candidates_found()

def find_duplicates(self, lib: library.Library):
"""Return a list of albums from `lib` with the same artist and
Expand Down Expand Up @@ -1017,6 +1050,7 @@ def choose_match(self, session):
choice = session.choose_match(self)
self.set_choice(choice)
session.log_choice(self)
session.task_match_chosen()

def reload(self):
"""Reload albums and items from the database."""
Expand Down Expand Up @@ -1072,10 +1106,11 @@ def _emit_imported(self, lib):
for item in self.imported_items():
plugins.send("item_imported", lib=lib, item=item)

def lookup_candidates(self):
def lookup_candidates(self, session: ImportSession):
prop = autotag.tag_item(self.item, search_ids=self.search_ids)
self.candidates = prop.candidates
self.rec = prop.recommendation
session.task_candidates_found()

def find_duplicates(self, lib):
"""Return a list of items from `lib` that have the same artist
Expand Down Expand Up @@ -1516,7 +1551,7 @@ def read_tasks(session: ImportSession):
log.info("Skipped {0} paths.", skipped)


def query_tasks(session: ImportSession):
def query_tasks(session: ImportSession) -> Iterable[ImportTask]:
"""A generator that works as a drop-in-replacement for read_tasks.
Instead of finding files from the filesystem, a query is used to
match items from the library.
Expand Down Expand Up @@ -1564,7 +1599,7 @@ def lookup_candidates(session: ImportSession, task: ImportTask):
# option. Currently all the IDs are passed onto the tasks directly.
task.search_ids = session.config["search_ids"].as_str_seq()

task.lookup_candidates()
task.lookup_candidates(session)


@pipeline.stage
Expand Down
12 changes: 6 additions & 6 deletions beets/test/_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,15 +121,15 @@ def assertNotExists(self, path):

def assertIsFile(self, path):
self.assertExists(path)
assert os.path.isfile(
syspath(path)
), "path exists, but is not a regular file: {!r}".format(path)
assert os.path.isfile(syspath(path)), (
"path exists, but is not a regular file: {!r}".format(path)
)

def assertIsDir(self, path):
self.assertExists(path)
assert os.path.isdir(
syspath(path)
), "path exists, but is not a directory: {!r}".format(path)
assert os.path.isdir(syspath(path)), (
"path exists, but is not a directory: {!r}".format(path)
)

def assert_equal_path(self, a, b):
"""Check that two paths are equal."""
Expand Down
34 changes: 34 additions & 0 deletions beets/test/helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -679,13 +679,30 @@ class ImportSessionFixture(ImportSession):
remaining albums, the metadata from the autotagger will be applied.
"""

created: int = 0
candidates_found: int = 0
match_chosen: int = 0
finalized: int = 0

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self._choices = []
self._resolutions = []

default_choice = importer.action.APPLY

def tasks_created(self, tasks: list[importer.ImportTask]) -> None:
self.created += len(tasks)

def task_candidates_found(self) -> None:
self.candidates_found += 1

def task_match_chosen(self) -> None:
self.match_chosen += 1

def task_finalized(self) -> None:
self.finalized += 1

def add_choice(self, choice):
self._choices.append(choice)

Expand Down Expand Up @@ -726,13 +743,30 @@ def resolve_duplicate(self, task, found_duplicates):


class TerminalImportSessionFixture(TerminalImportSession):
created: int = 0
candidates_found: int = 0
match_chosen: int = 0
finalized: int = 0

def __init__(self, *args, **kwargs):
self.io = kwargs.pop("io")
super().__init__(*args, **kwargs)
self._choices = []

default_choice = importer.action.APPLY

def tasks_created(self, tasks: list[importer.ImportTask]) -> None:
self.created += len(tasks)

def task_candidates_found(self) -> None:
self.candidates_found += 1

def task_match_chosen(self) -> None:
self.match_chosen += 1

def task_finalized(self) -> None:
self.finalized += 1

def add_choice(self, choice):
self._choices.append(choice)

Expand Down
Loading