# JSCC / Transformer / channel.py
import torch
import numpy as np

def erasure_channel(encoded_msg : torch.Tensor, erasure_prob: float = 0.5,
                    rng: "np.random.Generator | None" = None):
    """Erase random sequence positions of an encoded message.

    Each of the ``seq_len`` positions is independently erased with
    probability ``erasure_prob``. An erased position has its whole embedding
    vector set to zero. The same erasure pattern is applied to every element
    of the batch.

    Args:
        encoded_msg: 3-D tensor of shape ``(batch, seq_len, embed_len)``.
            It is not modified; a new tensor is returned.
        erasure_prob: Probability in ``[0, 1]`` that a position is erased.
        rng: Optional NumPy ``Generator``, for reproducible erasures. A fresh
            unseeded generator is used when ``None``.

    Returns:
        ``(erased_msg, erasures)``. ``erased_msg`` has the same shape, dtype
        and device as ``encoded_msg``. ``erasures`` is an int NumPy array of
        shape ``(seq_len,)`` in which 1 marks an erased position and 0 a kept
        one.

    Raises:
        ValueError: If ``encoded_msg`` is not 3-D, or if ``erasure_prob`` is
            outside ``[0, 1]``.
    """
    if not 0.0 <= erasure_prob <= 1.0:
        raise ValueError(f"erasure_prob must be in [0, 1], got {erasure_prob}")
    # Unpacking enforces the (batch, seq_len, embed_len) rank.
    _, seq_len, _ = encoded_msg.shape
    if rng is None:
        rng = np.random.default_rng()
    # Bernoulli mask. Casting the raw uniform draw straight to int (as before)
    # truncates every value to 0, so nothing would ever be erased.
    erasures = (rng.random(seq_len) < erasure_prob).astype(int)
    # Build the mask on the message's device and broadcast it over batch and
    # embedding dims. masked_fill is out-of-place, so the caller's tensor
    # (and any autograd history that saved it) stays intact.
    mask = torch.as_tensor(erasures, dtype=torch.bool, device=encoded_msg.device)
    erased_msg = encoded_msg.masked_fill(mask.view(1, seq_len, 1), 0)
    return erased_msg, erasures