/*
WRS.c
(c) AShelly (gist.github.com/ashelly)
Weighted random sampling with replacement of N items in O(1) time.
(After preparing an O(N)-sized buffer in O(N lg N) time.)
The concept is:
Randomly select a buffer index. Each index is selected with probability 1/N.
Each index stores the fraction of hits for which this item should be selected,
and the index of another item, which will be selected if this one is not.
*/
#include <assert.h>
#include <limits.h>
#include <stdio.h>
#include <stdlib.h>
//Helpers -- for demo only: based on rand(), so the distribution is not ideally uniform.
double rand_percent() {
  // Scale one rand() draw into the closed interval [0, 1];
  // a draw of RAND_MAX maps exactly to 1.0.
  double draw = (double)rand();
  return draw / RAND_MAX;
}
// Return a pseudo-random integer in the half-open range [first, last).
// When the range is empty (last <= first) the function returns first.
// Demo quality only: built on rand(), so the distribution is only
// approximately uniform.
int rand_in_range(int first, int last){
  if (last <= first) { return first; }
  // Divide by RAND_MAX + 1 so the fraction is strictly below 1; the result
  // can therefore never reach `last` (dividing by RAND_MAX could).
  double fraction = rand() / ((double)RAND_MAX + 1.0);
  // Compute the span in double so last - first cannot overflow int.
  double span = (double)last - (double)first;
  // The offset is non-negative, so truncation acts as floor; adding it in
  // long long keeps the intermediate sum from overflowing before the
  // final value (which lies in [first, last)) is converted back to int.
  long long offset = (long long)(fraction * span);
  return (int)((long long)first + offset);
}
//data structure
typedef struct wrs_data {
double share;
int pair;
int idx;
} wrs_t;
//sort helper
int wrs_sharecmp(const void* a, const void* b) {
double delta = ((wrs_t*)a)->share - ((wrs_t*)b)->share;
return (delta<0) ? -1 : (delta>0);
}
//Initialize the data structure
// Build the sampling table for N non-negative integer weights.
//   weights: array of N weights; each must be >= 0 and their sum > 0.
//   N:       number of weights (and of table entries); must be <= INT_MAX
//            because positions are stored in int fields.
// Returns a heap-allocated array of N entries (caller must free()), or NULL
// if the input is invalid (NULL/empty, too large, a negative weight, all
// weights zero) or the allocation fails.
wrs_t* wrs_create(int* weights, size_t N) {
  if (weights == NULL || N == 0 || N > (size_t)INT_MAX) { return NULL; }

  double sum = 0;
  for (size_t i = 0; i < N; i++) {
    if (weights[i] < 0) { return NULL; }
    sum += weights[i];
  }
  // An all-zero weight vector has no distribution (shares would be NaN).
  if (sum <= 0) { return NULL; }

  // One entry per item; calloc also checks N * size for overflow.
  wrs_t* data = calloc(N, sizeof *data);
  if (data == NULL) { return NULL; }

  double mean = sum / N;
  for (size_t i = 0; i < N; i++) {
    //what fraction of the ideal (flat) distribution is in this bucket?
    data[i].share = weights[i] / mean;
    data[i].pair = (int)N; // N means "not paired yet"
    data[i].idx = (int)i;
  }
  //sort ascending by share so under-full buckets come first
  qsort(data, N, sizeof *data, wrs_sharecmp);

  size_t j = N - 1; //the biggest bucket
  for (size_t i = 0; i < j; i++) {
    size_t check = i;
    double excess = 1.0 - data[check].share;
    while (excess > 0 && i < j) {
      //If this bucket has less samples than a flat distribution,
      //it will be hit more frequently than it should be.
      //So send excess hits to a bucket which has too many samples.
      data[check].pair = (int)j;
      // Account for the fact that the paired bucket will be hit more often.
      data[j].share -= excess;
      excess = 1.0 - data[j].share;
      // If the paired bucket is now under-full, it becomes the bucket to
      // fix, and its deficit goes to the new largest bucket at j-1.
      if (excess >= 0) { check = j--; }
    }
  }

  // A bucket that never received a pair should always select itself. In
  // exact arithmetic its share is already >= 1, but floating-point drift can
  // leave it just below 1, which would otherwise redirect a pick to the
  // out-of-range position N.
  for (size_t k = 0; k < N; k++) {
    if (data[k].pair == (int)N) { data[k].pair = (int)k; }
  }
  return data;
}
// O(1) weighted random sampling (after preparing the collection with
// wrs_create in O(N lg N)).
// Select a bucket uniformly, then draw a uniform fraction u in [0, 1):
// keep the bucket when u < its share, otherwise use its paired bucket.
//   collection: table returned by wrs_create for the same N.
//   N:          number of entries in collection.
// Returns the original index of the chosen item, or -1 if collection is
// NULL or N is 0.
int wrs_pick(wrs_t* collection, size_t N)
{
  if (collection == NULL || N == 0) { return -1; }
  // rand()/(RAND_MAX+1.0) lies in [0, 1), so the bucket index stays below N.
  size_t idx = (size_t)((rand() / ((double)RAND_MAX + 1.0)) * N);
  if (idx >= N) { idx = N - 1; } // guard against rounding for very large N
  double u = rand() / ((double)RAND_MAX + 1.0);
  // `>=` rather than `>` ensures a zero-share (zero-weight) bucket is never
  // kept, even when u is exactly 0.
  if (u >= collection[idx].share) {
    int pair = collection[idx].pair;
    // A pair outside [0, N) marks an unpaired bucket: keep the bucket itself.
    if (pair >= 0 && (size_t)pair < N) { idx = (size_t)pair; }
  }
  return collection[idx].idx;
}
// sample usage
// Relative weights of items 0..9; item 0 has zero weight.
int weights[] = {
  0, 1, 4, 10, 15,
  30, 16, 10, 8, 6
};
// Number of entries in weights (computed from the array itself).
#define NW (sizeof weights / sizeof *weights)
// Demo: build a table from the sample weights, draw 100 samples, and print
// the resulting histogram (one count per item, in item order).
int main(int argc,char*argv[]){
  (void)argc;
  (void)argv;
  //sample data
  wrs_t* collection = wrs_create(weights, NW);
  if (collection == NULL) {
    fprintf(stderr, "wrs_create failed\n");
    return EXIT_FAILURE;
  }
  //build new histogram
  int samples[NW] = {0};
  for (int i = 0; i < 100; i++) {
    int item = wrs_pick(collection, NW);
    // Defensive: only count indices that fit the histogram.
    if (item >= 0 && (size_t)item < NW) { samples[item]++; }
  }
  //check shape matches
  for (size_t i = 0; i < NW; i++) { printf("%d ", samples[i]); }
  printf("\n");
  free(collection);
  return 0;
}
LyoKICBXUlMuYwogIChjKSBBU2hlbGx5IChnaXN0LmdpdGh1Yi5jb20vYXNoZWxseSkKICBXZWlnaHRlZCByYW5kb20gc2FtcGxpbmcgd2l0aCByZXBsYWNlbWVudCBvZiBOIGl0ZW1zICBpbiBPKDEpIHRpbWUuCiAgKEFmdGVyIHByZXBhcmluZyBhIE8oTikgc2l6ZWQgYnVmZmVyIGluIE8oTmxnTikgdGltZS4pCiAgCiAgVGhlIGNvbmNlcHQgaXM6IAogICAgUmFuZG9tbHkgc2VsZWN0IGEgYnVmZmVyIGluZGV4LiAgRWFjaCBpbmRleCBpcyBzZWxlY3RlZCB3aXRoIHByb2JhYmxpbHR5IDEvTi4KICAgIEVhY2ggaW5kZXggc3RvcmVzIHRoZSBmcmFjdGlvbiBvZiBoaXRzIGZvciB3aGljaCB0aGlzIGl0ZW0gc2hvdWxkIGJlIHNlbGVjdGVkLAogICAgYW5kIHRoZSBpbmRleCBvZiBhbm90aGVyIGl0ZW0sIHdoaWNoIHdpbGwgYmUgc2VsZWN0ZWQgaWYgdGhpcyBvbmUgaXMgbm90LiAKICovCgojaW5jbHVkZSA8c3RkaW8uaD4KI2luY2x1ZGUgPHN0ZGxpYi5oPgojaW5jbHVkZSA8YXNzZXJ0Lmg+CgoKLy9oZWxwZXJzLiAgISFmb3IgZGVtbyBvbmx5LCBOb3QgaWRlYWwgdW5pZm9ybSBkaXN0cmlidXRpb24sIApkb3VibGUgcmFuZF9wZXJjZW50KCkgewogIHJldHVybiAoKGRvdWJsZSlyYW5kKCkpL1JBTkRfTUFYOwp9CgppbnQgcmFuZF9pbl9yYW5nZShpbnQgZmlyc3QsIGludCBsYXN0KXsKICByZXR1cm4gKGludCkocmFuZF9wZXJjZW50KCkqKGxhc3QtZmlyc3QpKTsKfQoKCi8vZGF0YSBzdHJ1Y3R1cmUKdHlwZWRlZiBzdHJ1Y3Qgd3JzX2RhdGEgewogIGRvdWJsZSBzaGFyZTsgCiAgaW50IHBhaXI7CiAgaW50IGlkeDsKfSB3cnNfdDsKCgovL3NvcnQgaGVscGVyCmludCB3cnNfc2hhcmVjbXAoY29uc3Qgdm9pZCogYSwgY29uc3Qgdm9pZCogYikgewogIGRvdWJsZSBkZWx0YSA9ICgod3JzX3QqKWEpLT5zaGFyZSAtICgod3JzX3QqKWIpLT5zaGFyZTsKICByZXR1cm4gKGRlbHRhPDApID8gLTEgOiAoZGVsdGE+MCk7Cn0KCgovL0luaXRpYWxpemUgdGhlIGRhdGEgc3RydWN0dXJlCndyc190KiB3cnNfY3JlYXRlKGludCogd2VpZ2h0cywgc2l6ZV90IE4pIHsKICB3cnNfdCogZGF0YSA9IG1hbGxvYyhzaXplb2Yod3JzX3QpKTsKICBkb3VibGUgc3VtID0gMDsKICBpbnQgaTsKICBmb3IgKGk9MDtpPE47aSsrKSB7IHN1bSs9d2VpZ2h0c1tpXTsgfQogIGZvciAoaT0wO2k8TjtpKyspIHsKICAgIC8vd2hhdCBwZXJjZW50IG9mIHRoZSBpZGVhbCBkaXN0cmlidXRpb24gaXMgaW4gdGhpcyBidWNrZXQ/CiAgICBkYXRhW2ldLnNoYXJlID0gd2VpZ2h0c1tpXS8oc3VtL04pOyAKICAgIGRhdGFbaV0ucGFpciA9IE47CiAgICBkYXRhW2ldLmlkeCA9IGk7CiAgfQogIC8vc29ydCBhc2NlbmRpbmcgYnkgc2l6ZQogIHFzb3J0KGRhdGEsTiwgc2l6ZW9mKHdyc190KSx3cnNfc2hhcmVjbXApOwoKICBpbnQgaj1OLTE7IC8vdGhlIGJpZ2dlc3QgYnVja2V0CiAgZm9yIChpPTA7aTxqO2krKykgewogICAgaW50IGNoZWNrID0gaTsKICAgIGRvdWJsZSBleGNl
c3MgPSAxLjAgLSBkYXRhW2NoZWNrXS5zaGFyZTsKICAgIHdoaWxlIChleGNlc3M+MCAmJiBpPGopIHsKICAgICAgLy9JZiB0aGlzIGJ1Y2tldCBoYXMgbGVzcyBzYW1wbGVzIHRoYW4gYSBmbGF0IGRpc3RyaWJ1dGlvbiwKICAgICAgLy9pdCB3aWxsIGJlIGhpdCBtb3JlIGZyZXF1ZW50bHkgdGhhbiBpdCBzaG91bGQgYmUuICAKICAgICAgLy9TbyBzZW5kIGV4Y2VzcyBoaXRzIHRvIGEgYnVja2V0IHdoaWNoIGhhcyB0b28gbWFueSBzYW1wbGVzLgogICAgICBkYXRhW2NoZWNrXS5wYWlyPWo7IAogICAgICAvLyBBY2NvdW50IGZvciB0aGUgZmFjdCB0aGF0IHRoZSBwYWlyZWQgYnVja2V0IHdpbGwgYmUgaGl0IG1vcmUgb2Z0ZW4sCiAgICAgIGRhdGFbal0uc2hhcmUgLT0gZXhjZXNzOyAgCiAgICAgIGV4Y2VzcyA9IDEuMCAtIGRhdGFbal0uc2hhcmU7CiAgICAgIC8vIElmIHBhaXIgYnVja2V0IG5vdyBoYXMgZXhjZXNzIGhpdHMsIHNlbmQgc29tZSB0byBuZXcgbGFyZ2VzdCBidWNrZXQgYXQgai0xCiAgICAgIGlmIChleGNlc3MgPj0gMCkgeyBjaGVjaz1qLS07fSAKICAgIH0KICB9CiAgcmV0dXJuIGRhdGE7Cn0KCgppbnQgd3JzX3BpY2sod3JzX3QqIGNvbGxlY3Rpb24sIHNpemVfdCBOKQovL08oMSkgcmFuZG9tIHNhbXBsaW5nIHdpdGggd2VpZ2h0cy4gKGFmdGVyIHByZXBhcmluZyB0aGUgY29sbGVjdGlvbiBpbiBPKE5sZ04pKQovL3JhbmRvbWx5IHNlbGVjdCBhIGJ1Y2tldCwgYW5kIGEgcGVyY2VudGFnZQovL2lmIHRoZSBwZXJjZW50YWdlIGlzIGdyZWF0ZXIgdGhhbiB0aGF0IGJ1Y2tldCdzIHNoYXJlIG9mIGhpdHMsIHVzZSBpdCdzIHBhaXJlZCBidWNrZXQuCnsKICBpbnQgaWR4ID0gcmFuZF9pbl9yYW5nZSgwLE4pOwogIGRvdWJsZSBwY3QgPSByYW5kX3BlcmNlbnQoKTsKICBpZiAocGN0ID4gY29sbGVjdGlvbltpZHhdLnNoYXJlKSB7IGlkeCA9IGNvbGxlY3Rpb25baWR4XS5wYWlyOyB9CiAgcmV0dXJuIGNvbGxlY3Rpb25baWR4XS5pZHg7Cn0gCgovLyBzYW1wbGUgdXNhZ2UKaW50IHdlaWdodHNbXT0gezAsMSw0LDEwLDE1LDMwLDE2LDEwLDgsNn07CiNkZWZpbmUgTlcgKHNpemVvZih3ZWlnaHRzKS9zaXplb2Yod2VpZ2h0c1swXSkpCgppbnQgbWFpbihpbnQgYXJnYyxjaGFyKmFyZ3ZbXSl7CgogIC8vc2FtcGxlIGRhdGEKCiAgd3JzX3QqIGNvbGxlY3Rpb24gPSB3cnNfY3JlYXRlKHdlaWdodHMsIE5XKTsKCiAgLy9idWlsZCBuZXcgaGlzdG9ncmFtIAogIGludCBpLHNhbXBsZXNbTlddPXswfTsKICBmb3IgKGk9MDsgaTwxMDA7IGkrKyl7CiAgICBzYW1wbGVzW3dyc19waWNrKGNvbGxlY3Rpb24sIE5XKV0rKzsKICB9CgogIC8vY2hlY2sgc2hhcGUgbWF0Y2hlcwogIGZvciAoaT0wO2k8Tlc7aSsrKSB7IHByaW50ZigiJWQgIixzYW1wbGVzW2ldKTt9CiAgcHJpbnRmKCJcbiIpOwogIHJldHVybiAwOwp9