Rust で重み付き抽選

重み付き抽選とは、ゲームのアイテムドロップの判定などに使われる選択手法です。

それぞれの選択率がリンゴ (50%)、みかん (20%)、柿 (20%)の場合、この確率に従って乱数で選びます。

Rust では rand クレートの WeightedIndex や rand_distr クレートの WeightedAliasIndex を利用できます。

rand::distributions::WeightedIndex

選択時の計算量は O(log N) です。N は要素数。つまり、要素数が多いと急に遅くなります。初期化時の計算量は O(N) です。

一度のサンプリングで乱数を1つ消費します。

コードより要素数とメモリ使用量の関係を調べます。

// https://docs.rs/rand/0.8.4/src/rand/distributions/weighted_index.rs.html#81-85

pub struct WeightedIndex<X: SampleUniform + PartialOrd> {
    cumulative_weights: Vec<X>,
    total_weight: X,
    weight_distribution: X::Sampler,
}

X は重みを指定する値の型です。
cumulative_weights のサイズは要素数と同じです。つまり、メモリ量は O(N) です。

rand_distr::weighted_alias::WeightedAliasIndex

選択時の計算量は O(1) です。初期化時の計算量は O(N) です。要素数が多いと事前に用意するデータが大きくなりメモリ消費量が増えます。

一度のサンプリングで乱数を2つ消費します。

同様にコードを確認します。

// https://docs.rs/rand_distr/0.4.2/src/rand_distr/weighted_alias.rs.html#72-77

pub struct WeightedAliasIndex<W: AliasableWeight> {
    aliases: Box<[u32]>,
    no_alias_odds: Box<[W]>,
    uniform_index: Uniform<u32>,
    uniform_within_weight_sum: Uniform<W>,
}

W は重みを指定する数値の型です。
no_alias_odds の要素数は N です。
aliases の要素数は N なので(ドキュメントより)、WeightedIndex とのメモリ量の差は sizeof(u32) * N です。

乱数の使用数調査

コードを見ると WeightedAliasIndex は一度のサンプリングで2つの乱数を消費しているようなので、以下のようにして乱数の消費数を調べます。

use rand::Rng;
use rand::distributions::{Distribution, WeightedIndex};
use rand_core::{Error, RngCore, SeedableRng};
use rand_pcg::Lcg64Xsh32;
use rand_distr::weighted_alias::WeightedAliasIndex;

/// Couting wrapper.
struct CountRnd {
    rng: Lcg64Xsh32,
    count_u32: usize,
}

impl CountRnd {
    pub fn new(seed: u64) -> Self {
        Self {
            rng: Lcg64Xsh32::seed_from_u64(seed),
            count_u32: 0,
        }
    }

    pub fn count_u32(&self) -> usize {
        self.count_u32
    }

    pub fn reset_count(&mut self) {
        self.count_u32 = 0;
    }
}

impl RngCore for CountRnd {
    fn next_u32(&mut self) -> u32 {
        self.count_u32 += 1;
        self.rng.next_u32()
    }

    fn next_u64(&mut self) -> u64 {
        self.rng.next_u64()
    }

    fn fill_bytes(&mut self, dest: &mut [u8]) {
        self.rng.fill_bytes(dest)
    }

    fn try_fill_bytes(&mut self, dest: &mut [u8]) -> Result<(), Error> {
        self.rng.try_fill_bytes(dest)
    }
}


fn main() {
    let mut rng = CountRnd::new(1234);

    let items = ['a', 'b', 'c'];
    let weights = [20, 10, 10];

    let dist = WeightedIndex::new(&weights.clone()).unwrap();
    for _ in 0..100 {
        let _i = items[dist.sample(&mut rng)];
    }
    println!("count: {}", rng.count_u32()); // 100
    rng.reset_count();


    let dist = WeightedAliasIndex::new(weights.to_vec()).unwrap();
    for _ in 0..100 {
        let _i = items[dist.sample(&mut rng)];
    }
    println!("count: {}", rng.count_u32()); // 200
}

二倍消費していることが分かります。

Rust では乱数を使用するときによく使われるクレートに重み付き抽選が実装されています。
抽選対象の確率から分布が作成され、その分布に与えられた乱数から抽選されます。
RngCore トレイトを実装している乱数生成器ならどれでも使用できます。

高速に判定したい場合はメモリの使用量に問題なければ WeightedAliasIndex を使えます。
しかし、一度のサンプリングで乱数を2つ消費したくなければ WeightedIndex を使いましょう。