Source code for easycv.utils.alias_multinomial

# Copyright (c) Alibaba, Inc. and its affiliates.
import numpy as np
import torch


[docs]class AliasMethod(object): ''' https://hips.seas.harvard.edu/blog/2013/03/03/the-alias-method-efficient-sampling-with-many-discrete-outcomes/ '''
[docs] def __init__(self, probs): if probs.sum() > 1: probs.div_(probs.sum()) K = len(probs) self.prob = torch.zeros(K) self.alias = torch.LongTensor([0] * K) # Sort the data into the outcomes with probabilities # that are larger and smaller than 1/K. smaller = [] larger = [] for kk, prob in enumerate(probs): self.prob[kk] = K * prob if self.prob[kk] < 1.0: smaller.append(kk) else: larger.append(kk) # Loop though and create little binary mixtures that # appropriately allocate the larger outcomes over the # overall uniform mixture. while len(smaller) > 0 and len(larger) > 0: small = smaller.pop() large = larger.pop() self.alias[small] = large self.prob[large] = (self.prob[large] - 1.0) + self.prob[small] if self.prob[large] < 1.0: smaller.append(large) else: larger.append(large) for last_one in smaller + larger: self.prob[last_one] = 1
[docs] def cuda(self): self.prob = self.prob.cuda() self.alias = self.alias.cuda()
[docs] def draw(self, N): ''' Draw N samples from multinomial ''' K = self.alias.size(0) kk = torch.zeros( N, dtype=torch.long, device=self.prob.device).random_(0, K) prob = self.prob.index_select(0, kk) alias = self.alias.index_select(0, kk) # b is whether a random number is greater than q b = torch.bernoulli(prob) oq = kk.mul(b.long()) oj = alias.mul((1 - b).long()) return oq + oj