Skip to content

Utils

moll.utils.args_support

args_support(deco: Callable)

Decorator to allow a decorator to be used with or without arguments.

Examples:

Decorate a decorator with @args_support

>>> @args_support
... def deco(fn, return_const=10):
...     return lambda: return_const

Now the decorator can be used as @deco

>>> @deco
... def hello_fn():
...     return "hello"
>>> hello_fn()
10

Or as @deco(...)

>>> @deco(return_const=30)
... def goodbye_fn():
...     return "goodbye"
>>> goodbye_fn()
30

moll.utils.cap_vector

cap_vector(vector: Array, max_length: float)

Truncate a vector to a maximum length.

moll.utils.create_key

create_key(seed: Seed = None) -> Array

Create a JAX PRNG key from a seed.

moll.utils.dist_matrix

dist_matrix(points, dist_fn, condensed=False)

Compute pairwise distances between points.

Examples:

>>> points = jnp.array([[0, 0], [1, 0]])
>>> dist_fn = lambda x, y: jnp.linalg.norm(x - y)
>>> dist_matrix(points, dist_fn).tolist()
[[0.0, 1.0], [1.0, 0.0]]
>>> dist_matrix(points, dist_fn, condensed=True).tolist()
[1.0]

moll.utils.dists_to_nearest_neighbor

dists_to_nearest_neighbor(points, dist_fn)

Compute the distance from each point to its nearest neighbor.

moll.utils.fill_diagonal

fill_diagonal(array: Array, val: float | int)

moll.utils.filter_

filter_(fn: Callable[..., Iterable], cond: Callable[[Any], bool] | None = None) -> Callable

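Decorator to filter the items of the iterable returned by a function and return the kept items as a list. If cond is given, only items satisfying it are kept; otherwise None items are dropped (as in the examples below).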

Examples:

>>> @filter_
... def numbers():
...     return [5, 15, None, 25]
>>> numbers()
[5, 15, 25]
>>> @filter_
... def numbers():
...     yield from [5, 15, None, 25]
>>> numbers()
[5, 15, 25]
>>> @filter_(cond=lambda x: x > 10)
... def numbers():
...     return [5, 15, 20, 25]
>>> numbers()
[15, 20, 25]

moll.utils.fold

fold(vec: NDArray, dim: int, *, dtype: DTypeLike | None = None) -> ndarray

Reduce vector dimension by folding.

Examples:

Fold to a specific size:

>>> fold([1, 0, 1, 0, 0, 0], dim=3)
array([1, 0, 1])

Folding a binary vector returns a binary vector:

>>> fold([True, False, True, False, False, False], dim=2)
array([ True, False])

Specify dtype to change the type of the output:

>>> fold([1, 0, 1, 0, 0, 0], dim=2, dtype=bool)
array([ True, False])

moll.utils.globs

globs(
    centers: Array,
    sizes: Sequence[int] | int = 10,
    stds: Sequence[float] | float = 1,
    seed: Seed = None,
    cap_radius: float | None = None,
    shuffle=True,
)

Generate points around centers.

moll.utils.group_files_by_size

group_files_by_size(
    files: list[Path], max_batches: int, *, sort_size=True, large_first=False
) -> Generator[list[Path], None, None]

Greedily groups files into batches by their size.

Note

Number of batches may be less than max_batches.

moll.utils.iter_lines

iter_lines(
    files: str | Iterable[str],
    skip_rows: int = 0,
    source_fn: Callable[[str], str] | Literal["filename", "stem"] | None = None,
    line_fn: Callable[[str], T] | None | Literal["split"] = None,
) -> Generator[tuple[str, int, str | T | tuple[str]], None, None]

Iterate over lines in files.

moll.utils.iter_precompute

iter_precompute(iterable: Iterable[D], n_precomputed: int) -> Generator[D, None, None]

Function to precompute a number of elements from an iterator.

Examples:

>>> numbers = iter_precompute(range(5), n_precomputed=2)
>>> list(numbers)
[0, 1, 2, 3, 4]
>>> numbers = iter_precompute(range(5), n_precomputed=100)
>>> list(numbers)
[0, 1, 2, 3, 4]
>>> numbers = iter_precompute(range(5), n_precomputed=1)
>>> list(numbers)
[0, 1, 2, 3, 4]

moll.utils.iter_slices

iter_slices(
    data: Iterable[D],
    slice_size: int,
    *,
    collate_fn: Callable[[Iterable[D]], B] = list,
    filter_fn: Callable[[D], bool] | None = None,
    transform_fn: Callable[[B], T] | Literal["transpose"] | None = None
) -> Generator[OneOrMany[B], None, None]

Split an iterable into batches of a given size.

Examples:

>>> list(iter_slices(range(10), 3))
[[0, 1, 2], [3, 4, 5], [6, 7, 8], [9]]
>>> list(iter_slices([], 3))
[]
>>> list(iter_slices(range(4), 5, collate_fn=tuple))
[(0, 1, 2, 3)]
>>> list(iter_slices(range(10), 3, collate_fn=sum))
[3, 12, 21, 9]

Filter function can be applied before batching:

>>> list(iter_slices(range(10), 3, filter_fn=lambda n: n % 2 == 0))
[[0, 2, 4], [6, 8]]

If a single data item has a heterogeneous type (e.g. a tuple of a number and a string), transform_fn="transpose" can be used to split it into batches of homogeneous type:

>>> data = [(1, "one"), (2, "two"), (3, "three"), (4, "four"), (5, "five")]
>>> for num, word in iter_slices(data, 2, transform_fn="transpose"):
...     print(" plus ".join(word), "is", sum(num))
one plus two is 3
three plus four is 7
five is 5

When transform_fn="transpose", tuples of batches are yielded rather than a single batch and collate_fn is applied individually to each batch.

>>> list(iter_slices(data, 2, transform_fn="transpose", collate_fn=list))
[([1, 2], ['one', 'two']), ([3, 4], ['three', 'four']), ([5], ['five'])]

moll.utils.iter_transpose

iter_transpose(
    data: Iterable[D], collate_fn: Callable[[Iterable[D]], R] = tuple
) -> Generator[R, None, None]

Transpose an iterable of iterables.

Examples:

>>> list(iter_transpose([[1, 2, 3], [4, 5, 6]]))
[(1, 4), (2, 5), (3, 6)]
>>> list(iter_transpose([[1, 2, 3], [4, 5, 6]], collate_fn=sum))
[5, 7, 9]

moll.utils.listify

listify(fn: Callable[..., Iterable]) -> Callable

Decorator to convert a generator function into a list-returning function.

Examples:

>>> @listify
... def numbers():
...     yield from range(5)
>>> numbers()
[0, 1, 2, 3, 4]
>>> @listify
... def empty():
...     if False:
...         yield
>>> empty()
[]

moll.utils.map_concurrently

map_concurrently(
    fn: Callable[[D], R],
    data: Iterable[D],
    *,
    n_workers: int | None = None,
    proc: bool = False,
    exception_fn: Callable[[Exception], Any] | Literal["ignore", "raise"] | None = "raise",
    buffer_size=150000
) -> Generator[R | None, None, None]

Apply a function to each item in an iterable in parallel.

Examples:

>>> def square(x):
...     return x**2
>>> list(map_concurrently(square, range(10)))
[0, 1, 4, 9, 16, 25, 36, 49, 64, 81]

By default, exceptions are raised:

>>> bad_fn = lambda x: (x - 2) / (x - 2) * x
>>> accumulator = []
>>> for result in map_concurrently(bad_fn, range(5)):
...     accumulator.append(result)
Traceback (most recent call last):
    ...
ZeroDivisionError: division by zero

All computations before the exception are always returned:

>>> accumulator
[0.0, 1.0]

Exceptions can be easily ignored:

>>> list(map_concurrently(bad_fn, range(5), exception_fn="ignore"))
[0.0, 1.0, 3.0, 4.0]

If exception_fn=None, None is yielded instead of the result:

>>> list(map_concurrently(bad_fn, range(5), exception_fn=None))
[0.0, 1.0, None, 3.0, 4.0]

Exceptions can be handled in a custom way:

>>> def const(e):
...     return 42
>>> list(map_concurrently(bad_fn, range(5), exception_fn=const))
[0.0, 1.0, 42, 3.0, 4.0]

By default, the number of workers equals the number of CPU cores and multithreading is used. Set proc=True to use multiprocessing instead:

>>> from math import factorial
>>> list(map_concurrently(factorial, range(10, 15), proc=True))
[3628800, 39916800, 479001600, 6227020800, 87178291200]

moll.utils.matrix_cross_sum

matrix_cross_sum(X: Array, i: int, j: int, row_only=False, crossover=True)

Compute the sum of the elements in the row i and the column j of the matrix X.

moll.utils.no_exceptions

no_exceptions(
    fn: Callable, exceptions: OneOrMany[type[BaseException]] = Exception, default: Any = None
) -> Callable

Decorator to catch exceptions and return a default value instead.

Examples:

>>> @no_exceptions(default="Error occurred")
... def bad_fn(x):
...     return x / 0
>>> bad_fn(10)
'Error occurred'
>>> @no_exceptions(exceptions=ZeroDivisionError)
... def bad_fn(x):
...     return x / 0
>>> bad_fn(10)
>>> @no_exceptions(exceptions=TypeError)
... def bad_fn(x):
...     return x / 0
>>> bad_fn(10)
Traceback (most recent call last):
    ...
ZeroDivisionError: division by zero

moll.utils.no_warnings

no_warnings(fn: Callable, suppress_rdkit=True) -> Callable

Decorator to suppress warnings in a function.

Examples:

>>> import warnings
>>> @no_warnings
... def warn():
...     warnings.warn("Boooo!!!", UserWarning)
>>> warn()
>>> from rdkit import Chem
>>> @no_warnings
... def warn_rdkit():
...     Chem.MolFromSmiles("C1=CC=CC=C1O")
>>> warn_rdkit()

moll.utils.partition

partition(data: list, *, n_partitions: int)

Partition the data into n_partitions partitions.

moll.utils.points_around

points_around(
    center: Array, n_points: int, std: float = 1, cap_radius: float | None = None, seed: Seed = None
)

Generate points around a center.

moll.utils.random_grid_points

random_grid_points(n_points: int, dim: int, n_ticks: int, spacing: int = 1, seed: Seed = None)

Generate random grid points.

moll.utils.time_int_seed

time_int_seed() -> int

Returns the number of microseconds since the epoch.

moll.utils.time_key

time_key() -> Array

Returns a JAX PRNG key based on the current time.

moll.utils.unpack_arguments

unpack_arguments(fn) -> Callable[..., Any]

Decorator to unpack a single sequence argument into the positional arguments of the decorated function.

Examples:

>>> @unpack_arguments
... def add(x, y, z=0):
...     return x + y + z
>>> add([1, 2])
3
>>> add((1, 2, 3))
6