Source code for dipy.stats.sketching
import os
import tempfile
import numpy as np
[docs]
def count_sketch(matrixa_name, matrixa_dtype, matrixa_shape, sketch_rows, tmp_dir):
    """Count Sketching algorithm to reduce the size of the matrix.

    The matrix A is read from a memmap file, flattened to 2-D (all leading
    axes merged into rows, last axis kept as columns), every row is
    multiplied by a random sign (+1 or -1) and then accumulated into one of
    ``sketch_rows`` randomly chosen rows of the sketch matrix C. C is stored
    in a new memmap file created inside ``tmp_dir``.

    Parameters
    ----------
    matrixa_name : str
        The name of the memmap file containing the matrix A.
    matrixa_dtype : dtype
        The dtype of the matrix A.
    matrixa_shape : tuple
        The shape of the matrix A.
    sketch_rows : int
        The number of rows in the sketch matrix.
    tmp_dir : str
        The directory to save the temporary files.

    Returns
    -------
    matrixc_file.name : str
        The name of the memmap file containing the sketch matrix.
    matrixc.dtype : dtype
        The dtype of the sketch matrix.
    matrixc.shape : tuple
        The shape of the sketch matrix.

    Notes
    -----
    Random draws come from NumPy's global random state (``np.random``), so
    results are reproducible after ``np.random.seed``. The sketch file is
    created with ``delete=False`` and is not removed by this function; the
    intermediate signed copy of A is always removed before returning, even
    when an error occurs.
    """
    # ``int(...)`` because ``np.prod`` of an empty tuple is the float 1.0,
    # which ``reshape`` would reject.
    n_rows = int(np.prod(matrixa_shape[:-1]))
    n_cols = matrixa_shape[-1]

    matrixa = np.squeeze(
        np.memmap(matrixa_name, dtype=matrixa_dtype, mode="r+", shape=matrixa_shape)
    ).reshape(n_rows, n_cols)

    # Reserve a unique file name for the signed copy of A; the memmap is
    # opened after the handle is closed so only one handle refers to the file.
    with tempfile.NamedTemporaryFile(
        delete=False, dir=tmp_dir, suffix="matrix_t"
    ) as matrixt_file:
        matrixt_name = matrixt_file.name

    matrixt = None
    try:
        matrixt = np.memmap(
            matrixt_name, dtype=matrixa_dtype, mode="w+", shape=matrixa.shape
        )

        # Target sketch row and random sign for every row of A.
        hashed_indices = np.random.choice(sketch_rows, n_rows, replace=True)
        rand_signs = np.random.choice(2, n_rows, replace=True) * 2 - 1

        # Copy in roughly 20 chunks to bound peak memory. The step is clamped
        # to 1 so matrices with fewer than 20 rows do not yield a zero step.
        chunk_size = max(1, n_rows // 20)
        for start in range(0, n_rows, chunk_size):
            stop = min(start + chunk_size, n_rows)
            matrixt[start:stop, :] = (
                matrixa[start:stop, :] * rand_signs[start:stop, np.newaxis]
            )

        with tempfile.NamedTemporaryFile(
            delete=False, dir=tmp_dir, suffix="matrix_C"
        ) as matrixc_file:
            matrixc_name = matrixc_file.name

        matrixc = np.memmap(
            matrixc_name,
            dtype=matrixa_dtype,
            mode="w+",
            shape=(sketch_rows, n_cols),
        )
        # Unbuffered add: rows hashed to the same index are all accumulated.
        np.add.at(matrixc, hashed_indices, matrixt)
        matrixc.flush()
    finally:
        # Drop the mapping before unlinking so the file can be removed on
        # platforms that refuse to delete mapped files.
        del matrixt
        os.unlink(matrixt_name)

    return matrixc_name, matrixc.dtype, matrixc.shape