6 changes: 3 additions & 3 deletions atomgen/data/data_collator.py
@@ -81,8 +81,8 @@ def torch_call(
         # Handle dict or lists with proper padding and conversion to tensor.
         if self.pad:
             if isinstance(examples[0], Mapping):
-                batch: Dict[str, Any] = self.tokenizer.pad(  # type: ignore[assignment]
-                    examples,  # type: ignore[arg-type]
+                batch: Dict[str, Any] = self.tokenizer.pad(
+                    examples,
                     return_tensors="pt",
                     pad_to_multiple_of=self.pad_to_multiple_of,
                 )
@@ -186,7 +186,7 @@ def torch_mask_tokens(
         inputs = torch.where(
             ~mask,
             inputs,
-            self.tokenizer.convert_tokens_to_ids(self.tokenizer.mask_token),  # type: ignore[arg-type]
+            self.tokenizer.convert_tokens_to_ids(self.tokenizer.mask_token),
         )
         labels = torch.where(mask, labels, -100)
         if special_tokens_mask is not None:
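
For context on the two codes being un-suppressed above: mypy reports [assignment] when a call's declared return type does not match the annotated variable, and [arg-type] when an argument does not match a parameter annotation. Once the tokenizer's pad() signature lines up with how the collator calls it, both per-line ignores become unnecessary. A minimal, hypothetical sketch of the pattern (stand-in classes, not AtomGen's or transformers' real signatures):

    from typing import Any, Dict, List

    class FakeBatchEncoding:
        """Pretend return type of a tokenizer's pad()."""

        def __init__(self, data: Dict[str, Any]) -> None:
            self.data = data

    class FakeTokenizer:
        def pad(self, examples: List[Dict[str, Any]]) -> FakeBatchEncoding:
            return FakeBatchEncoding({"input_ids": [e["input_ids"] for e in examples]})

    # pad() is declared to return FakeBatchEncoding, so annotating the result as
    # Dict[str, Any] is the kind of mismatch mypy reports as [assignment] (a
    # mismatched "examples" argument would be [arg-type]); the ignore silences
    # exactly that line, and it is no longer needed once the types agree.
    batch: Dict[str, Any] = FakeTokenizer().pad([{"input_ids": [1, 2, 3]}])  # type: ignore[assignment]
    print(batch)
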
10 changes: 5 additions & 5 deletions atomgen/data/tokenizer.py
@@ -13,7 +13,7 @@
 VOCAB_FILES_NAMES: Dict[str, str] = {"vocab_file": "tokenizer.json"}


-class AtomTokenizer(PreTrainedTokenizer):
+class AtomTokenizer(PreTrainedTokenizer):  # type: ignore[misc]
     """
     Tokenizer for atomistic data.

@@ -43,7 +43,7 @@ def __init__(
             [(ids, tok) for tok, ids in self.vocab.items()]
         )

-        super().__init__(  # type: ignore[no-untyped-call]
+        super().__init__(
             pad_token=pad_token,
             mask_token=mask_token,
             bos_token=bos_token,
@@ -63,7 +63,7 @@ def load_vocab(vocab_file: str) -> Dict[str, int]:
         )
         return vocab

-    def _tokenize(self, text: str) -> List[str]:  # type: ignore[override]
+    def _tokenize(self, text: str) -> List[str]:
         """Tokenize the text."""
         tokens = []
         i = 0
@@ -95,7 +95,7 @@ def convert_tokens_to_string(self, tokens: List[str]) -> str:
         """Convert the list of chemical symbol tokens to a concatenated string."""
         return "".join(tokens)

-    def pad(  # type: ignore[override]
+    def pad(
         self,
         encoded_inputs: Union[
             BatchEncoding,
@@ -155,7 +155,7 @@ def pad(  # type: ignore[override]
             pad_to_multiple_of=pad_to_multiple_of,
         )

-        return super().pad(
+        return super().pad(  # type: ignore[no-any-return]
            encoded_inputs=encoded_inputs,
            padding=padding,
            max_length=max_length,
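
The moves above keep the suppressions as narrow as possible: one `# type: ignore[misc]` on the class statement (the code mypy uses when a base class resolves to Any, e.g. because the dependency ships without complete type information), and `# type: ignore[no-any-return]` only on the line that returns a value mypy can only see as Any, instead of blanket ignores on each override. A minimal, hypothetical sketch of both codes, using a simulated untyped base rather than the real PreTrainedTokenizer:

    from typing import Any, Dict, List

    # Simulate a dependency class that mypy can only see as Any.
    UntypedBase: Any = type(
        "UntypedBase", (), {"pad": lambda self, encoded: dict(encoded[0])}
    )

    class MiniTokenizer(UntypedBase):  # type: ignore[misc]
        """Subclassing a value mypy types as Any is reported under [misc]."""

        def pad(self, encoded: List[Dict[str, Any]]) -> Dict[str, Any]:
            # The base pad() is Any to mypy, so with --warn-return-any this return
            # is flagged as [no-any-return]; the ignore stays on this line only.
            return super().pad(encoded)  # type: ignore[no-any-return]

    print(MiniTokenizer().pad([{"input_ids": [7, 8, 9]}]))
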
2 changes: 1 addition & 1 deletion atomgen/models/modeling_atomformer.py
@@ -2550,7 +2550,7 @@ def forward(
 class AtomformerPreTrainedModel(PreTrainedModel):  # type: ignore[no-untyped-call]
     """Base class for all transformer models."""

-    config_class = AtomformerConfig
+    config_class = AtomformerConfig  # type: ignore[assignment]
     base_model_prefix = "model"
     supports_gradient_checkpointing = True
     _no_split_modules = ["ParallelBlock"]
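
The same one-line change repeats for SchNetConfig and TransformerConfig below. As a rough, hypothetical illustration of why overriding config_class can trip mypy's [assignment] code (not the actual transformers.PreTrainedModel declaration): if the base class pins the attribute to a type the subclass's value cannot match, such as an inferred None, every concrete override needs a narrow per-line ignore:

    class FakePreTrainedModel:
        # A bare "config_class = None" on the base makes mypy infer the attribute's
        # type as None, so subclasses assigning a real class object are reported as
        # "Incompatible types in assignment"  [assignment].
        config_class = None

    class FakeAtomformerConfig:
        """Hypothetical stand-in for AtomformerConfig."""

    class FakeAtomformerModel(FakePreTrainedModel):
        config_class = FakeAtomformerConfig  # type: ignore[assignment]

    print(FakeAtomformerModel.config_class)
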
2 changes: 1 addition & 1 deletion atomgen/models/schnet.py
@@ -134,7 +134,7 @@ class SchNetPreTrainedModel(PreTrainedModel):  # type: ignore[no-untyped-call]
     simple interface for loading and exporting models.
     """

-    config_class = SchNetConfig
+    config_class = SchNetConfig  # type: ignore[assignment]
     base_model_prefix = "model"
     supports_gradient_checkpointing = False

2 changes: 1 addition & 1 deletion atomgen/models/tokengt.py
@@ -2510,7 +2510,7 @@ def custom_forward(*inputs: Any) -> Any:
 class TransformerPreTrainedModel(PreTrainedModel):  # type: ignore[no-untyped-call]
     """Base class for all transformer models."""

-    config_class = TransformerConfig
+    config_class = TransformerConfig  # type: ignore[assignment]
     base_model_prefix = "model"
     supports_gradient_checkpointing = True
     _no_split_modules = ["ParallelBlock"]