-
-
Notifications
You must be signed in to change notification settings - Fork 1.1k
feat: add ml/cluster/strided/dkmeansld
#9703
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: develop
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,152 @@ | ||
| /* eslint-disable valid-jsdoc */ | ||
| /** | ||
| * @license Apache-2.0 | ||
| * | ||
| * Copyright (c) 2026 The Stdlib Authors. | ||
| * | ||
| * Licensed under the Apache License, Version 2.0 (the "License"); | ||
| * you may not use this file except in compliance with the License. | ||
| * You may obtain a copy of the License at | ||
| * | ||
| * http://www.apache.org/licenses/LICENSE-2.0 | ||
| * | ||
| * Unless required by applicable law or agreed to in writing, software | ||
| * distributed under the License is distributed on an "AS IS" BASIS, | ||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| * See the License for the specific language governing permissions and | ||
| * limitations under the License. | ||
| */ | ||
|
|
||
| 'use strict'; | ||
|
|
||
| // MODULES // | ||
|
|
||
| var dkmeanselk = require( '@stdlib/ml/cluster/strided/dkmeanselk' ); | ||
| var dkmeansld = require( '@stdlib/ml/cluster/strided/dkmeansld' ); | ||
| var setReadOnly = require( '@stdlib/utils/define-nonenumerable-read-only-property' ); | ||
| var isMatrixLike = require( '@stdlib/assert/is-matrix-like' ); | ||
| var isInteger = require( '@stdlib/assert/is-integer' ); | ||
| var format = require( '@stdlib/string/format' ); | ||
| var initCentroids = require( './init_centroids.js' ); | ||
|
Check failure on line 30 in lib/node_modules/@stdlib/ml/cluster/strided/dkmeansld/lib/high_level_kmeans.js
|
||
|
|
||
|
|
||
| // MAIN // | ||
|
|
||
| /** | ||
| * Kmeans clustering. | ||
| * | ||
| * @private | ||
| * @param {PositiveInteger} k - number of clusters | ||
| * @param {(string|ndarray)} init - initialization method or initial centroids | ||
| * @param {(PositiveInteger|string)} replicates - number of replicates or 'auto' | ||
| * @throws {TypeError} first argument must be a positive integer | ||
| * @throws {TypeError} second argument must be a valid initialization method or matrix | ||
| * @throws {TypeError} third argument must be a positive integer or 'auto' | ||
| * @returns {Function} fit function | ||
| * | ||
| * @example | ||
| * var Float64Array = require( '@stdlib/array/float64' ); | ||
| * var ndarray = require( '@stdlib/ndarray/ctor' ); | ||
| * var kmeans = require( '@stdlib/ml/cluster/strided/dkmeansld' ); | ||
| * | ||
| */ | ||
| function dkmeans( k, init, replicates, maxIter, tol, metric, algorithm ) { // This will live in ml/cluster/kmeans/ctor, kept here just for reference | ||
| // TODO: refactor functions arguments and include a `options` argument to follow same pattern as `ml/incr/kmeans` | ||
| var model; | ||
| var reps; | ||
|
|
||
| // TODO: validate function arguments | ||
|
|
||
| if ( replicates === 'auto' ) { | ||
| if ( init === 'kmeans++' || isMatrixLike( init ) ) { | ||
| reps = 1; | ||
| } else if ( init === 'random' ) { | ||
| // reps = ?? | ||
| } else if ( init === 'forgy' ) { | ||
| // reps = ?? | ||
| } | ||
| } else if ( isInteger( replicates ) ) { | ||
| reps = replicates; | ||
| } else { | ||
| throw new TypeError( format( 'invalid argument. Argument specifying method for initialization must be either `kmeans++`, `random`, `forgy` or matrix specifying initial centroids. Value: `%s`.', init ) ); | ||
| } | ||
|
|
||
| // TODO: update the below attachment to follow similar pattern to stats/strided/ztests | ||
|
Check warning on line 74 in lib/node_modules/@stdlib/ml/cluster/strided/dkmeansld/lib/high_level_kmeans.js
|
||
| setReadOnly( model, 'fit', fit ); | ||
|
|
||
| return model; | ||
|
|
||
| /** | ||
| * Computes fitted cluster results using kmeans clustering. | ||
| * | ||
| * @private | ||
| * @param {MatrixLike} X - input data matrix | ||
| * @throws {TypeError} first argument must be a matrix-like object | ||
| * @returns {Object} clustering results | ||
| * | ||
| * @example | ||
| * var Float64Array = require( '@stdlib/array/float64' ); | ||
| * var ndarray = require( '@stdlib/ndarray/ctor' ); | ||
| * | ||
| */ | ||
| function fit( X ) { | ||
| var kmeansSingle; | ||
| var centroids; | ||
| var singleOut; | ||
| var out; | ||
| var sx1; | ||
| var sx2; | ||
| var ox; | ||
| var M; | ||
| var N; | ||
| var i; | ||
|
|
||
| // TODO: Step 1 : validate input matrix | ||
|
|
||
| // TODO: Step 2 : define arguments | ||
| M = X.shape[ 0 ]; | ||
| N = X.shape[ 1 ]; | ||
| sx1 = X.stride[ 0 ]; | ||
| sx2 = X.stride[ 1 ]; | ||
| ox = X.offset; | ||
|
|
||
| /** | ||
| * NOTE : M should be greater than k (M > k) | ||
| * ref : https://github.com/scikit-learn/scikit-learn/blob/d3898d9d57aeb1e960d266613a2e31b07bca39d7/sklearn/cluster/_kmeans.py#L876 | ||
| */ | ||
|
|
||
| if ( algorithm === 'elkan' ) { | ||
| kmeansSingle = dkmeanselk; | ||
| } else if ( algorithm === 'lloyd' ) { | ||
| kmeansSingle = dkmeansld; | ||
| } | ||
|
|
||
| for ( i = 0; i < reps; i++ ) { | ||
| centroids = initCentroids( X, init, k ); // ref : https://github.com/scikit-learn/scikit-learn/blob/d3898d9d57aeb1e960d266613a2e31b07bca39d7/sklearn/cluster/_kmeans.py#L961 | ||
| singleOut = kmeansSingle( M, N, k, metric, maxIter, tol, X, sx1, sx2, ox, centroids, k, N, 0 ); // magic number `0` because we generate the centroid array with no offset | ||
|
Check failure on line 126 in lib/node_modules/@stdlib/ml/cluster/strided/dkmeansld/lib/high_level_kmeans.js
|
||
|
|
||
| /** | ||
| * According to sklearn, `singleOut` should be { labels, inertia, centers, nIter } | ||
| * ref: https://github.com/scikit-learn/scikit-learn/blob/d3898d9d57aeb1e960d266613a2e31b07bca39d7/sklearn/cluster/_kmeans.py#L1531 | ||
| * ??? How should we handle this ??? | ||
| */ | ||
| } | ||
|
|
||
| /** | ||
| * TODO : Check convergence issue | ||
| * ref : https://github.com/scikit-learn/scikit-learn/blob/d3898d9d57aeb1e960d266613a2e31b07bca39d7/sklearn/cluster/_kmeans.py#L1545 | ||
| */ | ||
|
|
||
| /** | ||
| * TODO : Build the `out` object | ||
| * ref : https://github.com/stdlib-js/stdlib/pull/9703#discussion_r2681280854 | ||
| */ | ||
|
|
||
| return out; | ||
| } | ||
| } | ||
|
|
||
|
|
||
| // EXPORTS // | ||
|
|
||
| module.exports = dkmeans; | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,40 @@ | ||
| /** | ||
| * @license Apache-2.0 | ||
| * | ||
| * Copyright (c) 2026 The Stdlib Authors. | ||
| * | ||
| * Licensed under the Apache License, Version 2.0 (the "License"); | ||
| * you may not use this file except in compliance with the License. | ||
| * You may obtain a copy of the License at | ||
| * | ||
| * http://www.apache.org/licenses/LICENSE-2.0 | ||
| * | ||
| * Unless required by applicable law or agreed to in writing, software | ||
| * distributed under the License is distributed on an "AS IS" BASIS, | ||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| * See the License for the specific language governing permissions and | ||
| * limitations under the License. | ||
| */ | ||
|
|
||
| 'use strict'; | ||
|
|
||
| /** | ||
| * Compute fitted cluster results using Lloyd algorithm. | ||
| * | ||
| * @module @stdlib/ml/cluster/strided/dkmeansld | ||
| * | ||
| * @example | ||
| * var Float64Array = require( '@stdlib/array/float64' ); | ||
| * var ndarray = require( '@stdlib/ndarray/ctor' ); | ||
| * var kmeans = require( '@stdlib/ml/cluster/strided/dkmeansld' ); | ||
| * | ||
| */ | ||
|
|
||
| // MAIN // | ||
|
|
||
| var main = require( './main.js' ); | ||
|
|
||
|
|
||
| // EXPORTS // | ||
|
|
||
| module.exports = main; |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,165 @@ | ||
| /** | ||
| * @license Apache-2.0 | ||
| * | ||
| * Copyright (c) 2026 The Stdlib Authors. | ||
| * | ||
| * Licensed under the Apache License, Version 2.0 (the "License"); | ||
| * you may not use this file except in compliance with the License. | ||
| * You may obtain a copy of the License at | ||
| * | ||
| * http://www.apache.org/licenses/LICENSE-2.0 | ||
| * | ||
| * Unless required by applicable law or agreed to in writing, software | ||
| * distributed under the License is distributed on an "AS IS" BASIS, | ||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| * See the License for the specific language governing permissions and | ||
| * limitations under the License. | ||
| */ | ||
|
|
||
| 'use strict'; | ||
|
|
||
| // MODULES // | ||
|
|
||
| var dlacpy = require( '@stdlib/lapack/base/dlacpy' ).ndarray; | ||
| var Float64Array = require( '@stdlib/array/float64' ); | ||
| var Int32Array = require( '@stdlib/array/int32' ); | ||
| var dfill = require( '@stdlib/blas/ext/base/dfill' ); | ||
| var isEqualArray = require( '@stdlib/assert/is-equal-array' ); | ||
| var dcopy = require( '@stdlib/blas/base/dcopy' ).ndarray; | ||
| var deuclidean = require( '@stdlib/stats/strided/distances/deuclidean' ).ndarray; | ||
| var dcosine = require( '@stdlib/stats/strided/distances/deuclidean' ).ndarray; | ||
| var dcityblock = require( '@stdlib/stats/strided/distances/deuclidean' ).ndarray; | ||
|
|
||
|
|
||
| // MAIN // | ||
|
|
||
| /** | ||
| * Compute fitted cluster results using Lloyd algorithm. | ||
| * @param {PositiveInteger} M - number of samples | ||
| * @param {PositiveInteger} N - number of features | ||
| * @param {PositiveInteger} k - number of clusters | ||
| * @param {NonNegativeInteger} replicates - number of times to repeat clustering with different centroids | ||
| * @param {String} metric - distance metric | ||
| * @param {NonNegativeInteger} maxIter - maximum number of iterations. | ||
| * @param {integer} tol - relative tolerance before declaring convergence. | ||
| * @param {Float64Array} X - input strided matrix | ||
| * @param {integer} strideX1 - stride of the first dimension. | ||
| * @param {integer} strideX2 - stride of the second dimension. | ||
| * @param {integer} offsetX - starting index. | ||
| * @param {Float64Array} init - strided array containing initial centroid locations. | ||
| * @param {integer} strideInit1 - stride of first dimension. | ||
| * @param {integer} strideInit2 - stride of second dimension. | ||
| * @param {integer} strideInit3 - stride of the third dimension. | ||
| * @param {integer} offsetInit - initial index. | ||
| * @ returns {Result} results object | ||
| */ | ||
| function dkmeansld( M, N, k, replicates, metric, maxIter, tol, X, strideX1, strideX2, offsetX, init, strideInit1, strideInit2, strideInit3, offsetInit ) { // eslint-disable-line max-len | ||
| var centroidShift; | ||
| var centroidsNew; | ||
| var strictConv; | ||
| var labelsOld; | ||
| var centroids; | ||
| var bestDist; | ||
| var inertia; | ||
| var labels; | ||
| var counts; | ||
| var shift; | ||
| var same; | ||
| var dist; | ||
| var best; | ||
| var iter; | ||
| var out; | ||
| var ox; | ||
| var i; | ||
| var j; | ||
| var c; | ||
| var d; | ||
|
|
||
| centroids = new Float64Array( k*N ); | ||
| centroidsNew = new Float64Array( k*N ); | ||
| labels = new Int32Array( M ); | ||
| labelsOld = new Int32Array( M ); | ||
| counts = new Int32Array( k ); // q: sklearn supports sample_weights, should we do the same? if yes, change it to Float64Array | ||
|
|
||
| // centroidShift = new Float64Array( k ); | ||
|
|
||
| dlacpy( 'all', k, N, init, strideInit2, strideInit3, offsetInit, centroids, strideInit2, strideInit3, 0 ); | ||
|
|
||
| if ( metric === 'euclidean' ) { | ||
| dist = deuclidean; // TODO: change it to dsquared-euclidean once implemented | ||
| } else if ( metric === 'cosine' ) { | ||
| dist = dcosine; // TODO: change it to dsquared-cosine once implemented | ||
| } else if ( metric === 'cityblock' ) { | ||
| dist = dcityblock; | ||
| } | ||
|
|
||
| // this is a dense implementation, sklearn also has a sparse implementation | ||
| // https://github.com/scikit-learn/scikit-learn/blob/d3898d9d57aeb1e960d266613a2e31b07bca39d7/sklearn/cluster/_kmeans.py#L696C1-L700C46 | ||
| strictConv = false; | ||
| for ( iter = 0; iter < maxIter; iter++ ) { | ||
| dfill( k*N, 0.0, centroidsNew, 1 ); | ||
| dfill( k, 0, counts, 1 ); // How do I fill it with a int32? | ||
|
|
||
| ox = offsetX; | ||
| for ( i = 0; i < M; i++ ) { | ||
| best = 0; | ||
| bestDist = dist( N, X, strideX2, ox, centroids, 1, 0 ); | ||
| for ( c = 1; c < k; c++ ) { | ||
| d = dist( N, X, strideX2, ox, centroids, 1, c*N ); | ||
| if ( d < bestDist ) { | ||
| bestDist = d; | ||
| best = c; | ||
| } | ||
| } | ||
|
|
||
| labels[ i ] = best; | ||
| counts[ best ] += 1; | ||
| ox += strideX1; | ||
|
|
||
| for ( j = 0; j < N; j++ ) { | ||
| centroidsNew[ ( best*N )+j ] += X[ offsetX + (i*strideX1) + (j*strideX2) ]; // eslint-disable-line max-len | ||
| } | ||
| } | ||
|
|
||
| for ( c = 0; c < k; c++ ) { | ||
| if ( counts[ c ] > 0 ) { | ||
| for ( j = 0; j < N; j++ ) { | ||
| centroidsNew[ ( c*N )+j ] /= counts[ c ]; | ||
| } | ||
| } else { | ||
| for ( j = 0; j < N; j++ ) { | ||
| centroidsNew[ ( c*N )+j ] = centroids[ ( c*N )+j ]; | ||
| } | ||
| } | ||
| } | ||
|
|
||
| d = centroidsNew[ 0 ] - centroids[ 0 ]; | ||
| shift = d * d; | ||
| for ( i = 1; i < k * N; i++ ) { | ||
| d = centroidsNew[ i ] - centroids[ i ]; | ||
| shift += d * d; | ||
| centroids[ i ] = centroidsNew[ i ]; | ||
| } | ||
|
|
||
| if ( isEqualArray( labels, labelsOld ) ) { | ||
| strictConv = true; | ||
| break; | ||
| } else { | ||
| // TODO: implement center shift | ||
| } | ||
| dcopy( M, labels, 1, 0, labelsOld, 1, 0 ); // Magic number `1` and `0` because we assume labels are stored contiguously | ||
| } | ||
|
|
||
| if (!strictConv) { | ||
| // TODO: Rerun the E-step | ||
| } | ||
|
|
||
| // TODO: Compute intertia | ||
|
|
||
| return out; // TODO: create a results object similar to stats/base/ztest/two-sample/results/factory | ||
| } | ||
|
|
||
|
|
||
| // EXPORTS // | ||
|
|
||
| module.exports = dkmeansld; | ||
Uh oh!
There was an error while loading. Please reload this page.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I realized after looking at this and the internal array allocation that, similar to https://github.com/stdlib-js/stdlib/tree/develop/lib/node_modules/%40stdlib/stats/strided/dztest, we'll want an
outparameter for a pre-allocated results object. In the C implementation, we don't want to be dynamically allocating memory.Uh oh!
There was an error while loading. Please reload this page.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I am also realizing that we'll want a
samples(number of samples, M) and afeatures(number of features, N) property on the results object. Having these values will allow consumers to properly read theinit,centroids, andstatisticslinear memory buffers, independent of the original function invocation.Uh oh!
There was an error while loading. Please reload this page.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Current thinking is the following signature:
where
centroids,statistics, andlabelsare output arrays andoutis a "results" object with the following fields:Similar to
ztest, several of the parameter values should be copied over to what will be an uninitialized results object, such as M, N, k.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Signature with abbreviated parameter names: