from typing import Optional, Tuple

import numpy as np
import torch


def erasure_channel(
    encoded_msg: torch.Tensor,
    *,
    p: float = 0.5,
    rng: Optional[np.random.Generator] = None,
) -> Tuple[torch.Tensor, np.ndarray]:
    """Simulate an erasure channel over the sequence positions of a batch.

    Each of the ``seq_len`` positions is independently erased with
    probability ``p``. An erased position has its whole embedding vector set
    to zero, for every element of the batch (the same erasure pattern is
    shared across the batch).

    Args:
        encoded_msg: Tensor of shape ``(batch, seq_len, embed_len)``. It is
            not modified; the erased result is returned as a new tensor.
        p: Per-position erasure probability, in ``[0, 1]``.
        rng: NumPy generator used to draw the erasure pattern. Pass a seeded
            generator for reproducible runs. If None, a fresh unseeded
            generator is created.

    Returns:
        ``(erased_msg, erasures)``. ``erased_msg`` has the same shape, dtype
        and device as ``encoded_msg``. ``erasures`` is an int ndarray of
        shape ``(seq_len,)`` holding 1 at erased positions and 0 elsewhere.

    Raises:
        ValueError: if ``encoded_msg`` is not 3-D or ``p`` is outside [0, 1].
    """
    if not 0.0 <= p <= 1.0:
        raise ValueError(f"erasure probability p must be in [0, 1], got {p}")
    # Unpacking raises ValueError unless the tensor is exactly 3-D.
    _, seq_len, _ = encoded_msg.shape
    if rng is None:
        rng = np.random.default_rng()

    # Threshold the uniform draw to get a Bernoulli(p) mask. Casting the
    # uniform [0, 1) floats straight to int would always give 0.
    erasures = (rng.random(seq_len) < p).astype(int)

    # Boolean mask on dim 1 (sequence). Chained indexing such as msg[:][idx]
    # would index the batch dimension instead.
    mask = torch.as_tensor(erasures.astype(bool), device=encoded_msg.device)
    erased_msg = encoded_msg.clone()
    erased_msg[:, mask] = 0
    return erased_msg, erasures