Utils
moll.utils.args_support
args_support(deco: Callable)
Decorator to allow a decorator to be used with or without arguments.
Examples:
Decorate a decorator with @args_support
>>> @args_support
... def deco(fn, return_const=10):
... return lambda: return_const
Now the decorator can be used as @deco
>>> @deco
... def hello_fn():
... return "hello"
>>> hello_fn()
10
Or as @deco(...)
>>> @deco(return_const=30)
... def goodbye_fn():
... return "goodbye"
>>> goodbye_fn()
30
moll.utils.cap_vector
cap_vector(vector: Array, max_length: float)
Truncate a vector to a maximum length.
moll.utils.create_key
create_key(seed: Seed = None) -> Array
Create a JAX PRNG key from a seed.
moll.utils.dist_matrix
dist_matrix(points, dist_fn, condensed=False)
Compute pairwise distances between points.
Examples:
>>> points = jnp.array([[0, 0], [1, 0]])
>>> dist_fn = lambda x, y: jnp.linalg.norm(x - y)
>>> dist_matrix(points, dist_fn).tolist()
[[0.0, 1.0], [1.0, 0.0]]
>>> dist_matrix(points, dist_fn, condensed=True).tolist()
[1.0]
moll.utils.dists_to_nearest_neighbor
dists_to_nearest_neighbor(points, dist_fn)
Compute the distance from each point to its nearest neighbor.
moll.utils.fill_diagonal
fill_diagonal(array: Array, val: float | int)
moll.utils.filter_
filter_(fn: Callable[..., Iterable], cond: Callable[[Any], bool] | None = None) -> Callable
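Decorator to filter the items of an iterable returned by a function and collect the remaining items into a list. By default, `None` items are removed; if `cond` is given, only items for which `cond` returns `True` are kept.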
Examples:
>>> @filter_
... def numbers():
... return [5, 15, None, 25]
>>> numbers()
[5, 15, 25]
>>> @filter_
... def numbers():
... yield from [5, 15, None, 25]
>>> numbers()
[5, 15, 25]
>>> @filter_(cond=lambda x: x > 10)
... def numbers():
... return [5, 15, 20, 25]
>>> numbers()
[15, 20, 25]
moll.utils.fold
fold(vec: NDArray, dim: int, *, dtype: DTypeLike | None = None) -> ndarray
Reduce vector dimension by folding.
Examples:
Fold to a specific size:
>>> fold([1, 0, 1, 0, 0, 0], dim=3)
array([1, 0, 1])
Folding a boolean vector sums the folded elements, so the result is an integer vector:
>>> fold([True, False, True, False, False, False], dim=2)
array([2, 0])
Specify `dtype` to change the type of the output:
>>> fold([1, 0, 1, 0, 0, 0], dim=2, dtype=bool)
array([ True, False])
moll.utils.globs
globs(
centers: Array,
sizes: Sequence[int] | int = 10,
stds: Sequence[float] | float = 1,
seed: Seed = None,
cap_radius: float | None = None,
shuffle=True,
)
Generate points around centers.
moll.utils.group_files_by_size
group_files_by_size(
files: list[Path], max_batches: int, *, sort_size=True, large_first=False
) -> Generator[list[Path], None, None]
Greedily group files into batches by their size.
Note: the number of batches may be less than `max_batches`.
moll.utils.iter_lines
iter_lines(
files: str | Iterable[str],
skip_rows: int = 0,
source_fn: Callable[[str], str] | Literal["filename", "stem"] | None = None,
line_fn: Callable[[str], T] | None | Literal["split"] = None,
) -> Generator[tuple[str, int, str | T | tuple[str]], None, None]
Iterate over lines in files.
moll.utils.iter_precompute
iter_precompute(iterable: Iterable[D], n_precomputed: int) -> Generator[D, None, None]
Function to precompute a number of elements from an iterator.
Examples:
>>> numbers = iter_precompute(range(5), n_precomputed=2)
>>> list(numbers)
[0, 1, 2, 3, 4]
>>> numbers = iter_precompute(range(5), n_precomputed=100)
>>> list(numbers)
[0, 1, 2, 3, 4]
>>> numbers = iter_precompute(range(5), n_precomputed=1)
>>> list(numbers)
[0, 1, 2, 3, 4]
moll.utils.iter_slices
iter_slices(
data: Iterable[D],
slice_size: int,
*,
collate_fn: Callable[[Iterable[D]], B] = list,
filter_fn: Callable[[D], bool] | None = None,
transform_fn: Callable[[B], T] | Literal["transpose"] | None = None
) -> Generator[OneOrMany[B], None, None]
Split an iterable into batches of a given size.
Examples:
>>> list(iter_slices(range(10), 3))
[[0, 1, 2], [3, 4, 5], [6, 7, 8], [9]]
>>> list(iter_slices([], 3))
[]
>>> list(iter_slices(range(4), 5, collate_fn=tuple))
[(0, 1, 2, 3)]
>>> list(iter_slices(range(10), 3, collate_fn=sum))
[3, 12, 21, 9]
Filter function can be applied before batching:
>>> list(iter_slices(range(10), 3, filter_fn=lambda n: n % 2 == 0))
[[0, 2, 4], [6, 8]]
If each data item is heterogeneous (e.g. a tuple of values of different types), `transform_fn="transpose"` can be used to split the data into batches of homogeneous type:
>>> data = [(1, "one"), (2, "two"), (3, "three"), (4, "four"), (5, "five")]
>>> for num, word in iter_slices(data, 2, transform_fn="transpose"):
... print(" plus ".join(word), "is", sum(num))
one plus two is 3
three plus four is 7
five is 5
When `transform_fn="transpose"`, tuples of batches are yielded rather than a single batch, and `collate_fn` is applied individually to each batch.
>>> list(iter_slices(data, 2, transform_fn="transpose", collate_fn=list))
[([1, 2], ['one', 'two']), ([3, 4], ['three', 'four']), ([5], ['five'])]
moll.utils.iter_transpose
iter_transpose(
data: Iterable[D], collate_fn: Callable[[Iterable[D]], R] = tuple
) -> Generator[R, None, None]
Transpose an iterable of iterables.
Examples:
>>> list(iter_transpose([[1, 2, 3], [4, 5, 6]]))
[(1, 4), (2, 5), (3, 6)]
>>> list(iter_transpose([[1, 2, 3], [4, 5, 6]], collate_fn=sum))
[5, 7, 9]
moll.utils.listify
listify(fn: Callable[..., Iterable]) -> Callable
Decorator to convert a generator function into a list-returning function.
Examples:
>>> @listify
... def numbers():
... yield from range(5)
>>> numbers()
[0, 1, 2, 3, 4]
>>> @listify
... def empty():
... if False:
... yield
>>> empty()
[]
moll.utils.map_concurrently
map_concurrently(
fn: Callable[[D], R],
data: Iterable[D],
*,
n_workers: int | None = None,
proc: bool = False,
exception_fn: Callable[[Exception], Any] | Literal["ignore", "raise"] | None = "raise",
buffer_size=150000
) -> Generator[R | None, None, None]
Apply a function to each item in an iterable in parallel.
Examples:
>>> def square(x):
... return x**2
>>> list(map_concurrently(square, range(10)))
[0, 1, 4, 9, 16, 25, 36, 49, 64, 81]
By default, exceptions are raised:
>>> bad_fn = lambda x: (x - 2) / (x - 2) * x
>>> accumulator = []
>>> for result in map_concurrently(bad_fn, range(5)):
... accumulator.append(result)
Traceback (most recent call last):
...
ZeroDivisionError: division by zero
Results for all items processed before the failing one are still yielded:
>>> accumulator
[0.0, 1.0]
Exceptions can be easily ignored:
>>> list(map_concurrently(bad_fn, range(5), exception_fn="ignore"))
[0.0, 1.0, 3.0, 4.0]
If `exception_fn=None`, `None` is yielded in place of the result for each failing item:
>>> list(map_concurrently(bad_fn, range(5), exception_fn=None))
[0.0, 1.0, None, 3.0, 4.0]
Exceptions can be handled in a custom way:
>>> def const(e):
... return 42
>>> list(map_concurrently(bad_fn, range(5), exception_fn=const))
[0.0, 1.0, 42, 3.0, 4.0]
By default, the number of workers equals the number of CPU cores and multithreading is used. Set `proc=True` to enable multiprocessing:
>>> from math import factorial
>>> list(map_concurrently(factorial, range(10, 15), proc=True))
[3628800, 39916800, 479001600, 6227020800, 87178291200]
moll.utils.matrix_cross_sum
matrix_cross_sum(X: Array, i: int, j: int, row_only=False, crossover=True)
Compute the sum of the elements in row `i` and column `j` of the matrix `X`.
moll.utils.no_exceptions
no_exceptions(
fn: Callable, exceptions: OneOrMany[type[BaseException]] = Exception, default: Any = None
) -> Callable
Decorator to catch exceptions and return a default value instead.
Examples:
>>> @no_exceptions(default="Error occurred")
... def bad_fn(x):
... return x / 0
>>> bad_fn(10)
'Error occurred'
>>> @no_exceptions(exceptions=ZeroDivisionError)
... def bad_fn(x):
... return x / 0
>>> bad_fn(10)
>>> @no_exceptions(exceptions=TypeError)
... def bad_fn(x):
... return x / 0
>>> bad_fn(10)
Traceback (most recent call last):
...
ZeroDivisionError: division by zero
moll.utils.no_warnings
no_warnings(fn: Callable, suppress_rdkit=True) -> Callable
Decorator to suppress warnings in a function.
Examples:
>>> import warnings
>>> @no_warnings
... def warn():
... warnings.warn("Boooo!!!", UserWarning)
>>> warn()
>>> from rdkit import Chem
>>> @no_warnings
... def warn_rdkit():
... Chem.MolFromSmiles("C1=CC=CC=C1O")
>>> warn_rdkit()
moll.utils.partition
partition(data: list, *, n_partitions: int)
Partition the data into `n_partitions` partitions.
moll.utils.points_around
points_around(
center: Array, n_points: int, std: float = 1, cap_radius: float | None = None, seed: Seed = None
)
Generate points around a center.
moll.utils.random_grid_points
random_grid_points(n_points: int, dim: int, n_ticks: int, spacing: int = 1, seed: Seed = None)
Generate random grid points.
moll.utils.time_int_seed
time_int_seed() -> int
Return the number of microseconds since the epoch.
moll.utils.time_key
time_key() -> Array
Return a JAX PRNG key based on the current time.
moll.utils.unpack_arguments
unpack_arguments(fn) -> Callable[..., Any]
Decorator that unpacks a single sequence argument into the function's positional arguments.
Examples:
>>> @unpack_arguments
... def add(x, y, z=0):
... return x + y + z
>>> add([1, 2])
3
>>> add((1, 2, 3))
6