Source code for dipy.stats.sketching

import os
import tempfile

import numpy as np


def count_sketch(matrixa_name, matrixa_dtype, matrixa_shape, sketch_rows, tmp_dir):
    """Count Sketching algorithm to reduce the size of the matrix.

    The matrix A is viewed as a 2D array of shape
    ``(prod(matrixa_shape[:-1]), matrixa_shape[-1])``. Every row is multiplied
    by a random sign (+1 or -1) and added to one randomly chosen row of the
    sketch matrix, which therefore has ``sketch_rows`` rows and the same
    number of columns as A. Random draws use NumPy's global random state, so
    results are reproducible after ``np.random.seed``.

    Parameters
    ----------
    matrixa_name : str
        The name of the memmap file containing the matrix A.
    matrixa_dtype : dtype
        The dtype of the matrix A.
    matrixa_shape : tuple
        The shape of the matrix A.
    sketch_rows : int
        The number of rows in the sketch matrix.
    tmp_dir : str
        The directory to save the temporary files.

    Returns
    -------
    matrixc_file.name : str
        The name of the memmap file containing the sketch matrix. The file
        is created with ``delete=False`` and is not removed by this function.
    matrixc.dtype : dtype
        The dtype of the sketch matrix.
    matrixc.shape : tuple
        The shape of the sketch matrix.

    """
    # int() keeps reshape happy for 1-D shapes, where np.prod(()) is 1.0.
    n_rows = int(np.prod(matrixa_shape[:-1]))
    n_cols = matrixa_shape[-1]

    # A is only read here, so map it read-only.
    matrixa = np.memmap(
        matrixa_name, dtype=matrixa_dtype, mode="r", shape=matrixa_shape
    ).reshape(n_rows, n_cols)

    # Reserve a unique name for the signed copy of A; the memmap itself is
    # created below so that the file is always removed, even on failure.
    with tempfile.NamedTemporaryFile(
        delete=False, dir=tmp_dir, suffix="matrix_t"
    ) as matrixt_file:
        matrixt_name = matrixt_file.name

    matrixt = None
    try:
        matrixt = np.memmap(
            matrixt_name, dtype=matrixa_dtype, mode="w+", shape=(n_rows, n_cols)
        )

        hashed_indices = np.random.choice(sketch_rows, n_rows, replace=True)
        rand_signs = np.random.choice(2, n_rows, replace=True) * 2 - 1

        # Process A in roughly 20 chunks to bound memory use. The step must be
        # at least 1: with fewer than 20 rows ``n_rows // 20`` is 0 and
        # range() would raise ValueError.
        chunk_rows = max(1, n_rows // 20)
        for start in range(0, n_rows, chunk_rows):
            stop = min(start + chunk_rows, n_rows)
            matrixt[start:stop, :] = (
                matrixa[start:stop, :] * rand_signs[start:stop, np.newaxis]
            )

        with tempfile.NamedTemporaryFile(
            delete=False, dir=tmp_dir, suffix="matrix_C"
        ) as matrixc_file:
            matrixc = np.memmap(
                matrixc_file.name,
                dtype=matrixa_dtype,
                mode="w+",
                shape=(sketch_rows, n_cols),
            )

        # Unbuffered add: rows hashed to the same bucket accumulate.
        np.add.at(matrixc, hashed_indices, matrixt)
        matrixc.flush()
    finally:
        # Drop the mapping before unlinking (required on Windows).
        del matrixt
        os.unlink(matrixt_name)

    return matrixc_file.name, matrixc.dtype, matrixc.shape