bagua.torch_api.model_parallel.moe.sharded_moe¶
Module Contents¶
- class bagua.torch_api.model_parallel.moe.sharded_moe.MOELayer(gate, experts, num_local_experts, group=None)¶
Bases:
Base
MOELayer module which implements Mixture of Experts as described in GShard.
gate = TopKGate(model_dim, num_experts)
moe = MOELayer(gate, expert)
output = moe(input)
l_aux = moe.l_aux
- Parameters:
gate (torch.nn.Module) – gate network
experts (torch.nn.Module) – expert network
num_local_experts (int) –
group (Optional[Any]) –
- forward(*input, **kwargs)¶
- Parameters:
input (torch.Tensor) –
kwargs (Any) –
- Return type:
torch.Tensor
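A hedged usage sketch expanding the snippet in the MOELayer docstring above. The expert definition, tensor shapes, and the use of the default torch.distributed process group are illustrative assumptions, not part of the documented API; MOELayer exchanges tokens between experts with collective communication, so a process group (even a single-process one) must be initialized before the forward pass.

import torch
from bagua.torch_api.model_parallel.moe.sharded_moe import MOELayer, TopKGate

# Assumption: torch.distributed.init_process_group(...) has already been called.
model_dim, num_experts, num_local_experts = 512, 4, 4

gate = TopKGate(model_dim, num_experts, k=1)
expert = torch.nn.Linear(model_dim, model_dim)   # toy stand-in for an expert network
moe = MOELayer(gate, expert, num_local_experts)  # group=None -> default process group assumed

tokens = torch.randn(8, 16, model_dim)           # (batch, seq, model_dim) is an assumed layout
output = moe(tokens)                             # same shape as the input tokens
l_aux = moe.l_aux                                # auxiliary load-balancing loss from the gate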
- class bagua.torch_api.model_parallel.moe.sharded_moe.TopKGate(model_dim, num_experts, k=1, capacity_factor=1.0, eval_capacity_factor=1.0, min_capacity=4, noisy_gate_policy=None)¶
Bases:
torch.nn.Module
Gate module which implements top-k gating as described in GShard.
gate = TopKGate(model_dim, num_experts)
l_aux, combine_weights, dispatch_mask = gate(input)
- Parameters:
model_dim (int) – size of model embedding dimension
num_experts (int) – number of experts in model
k (int) –
capacity_factor (float) –
eval_capacity_factor (float) –
min_capacity (int) –
noisy_gate_policy (Optional[str]) –
- wg :torch.nn.Linear¶
- forward(input, used_token=None)¶
- Parameters:
input (torch.Tensor) –
used_token (torch.Tensor) –
- Return type:
Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]
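A minimal sketch of calling the gate directly, assuming a flattened (num_tokens, model_dim) input. The documented return type is a 4-tuple of tensors, so the sketch unpacks only the first three values shown in the docstring snippet above.

import torch
from bagua.torch_api.model_parallel.moe.sharded_moe import TopKGate

model_dim, num_experts = 512, 4
gate = TopKGate(model_dim, num_experts, k=1, capacity_factor=1.0)

tokens = torch.randn(32, model_dim)              # (num_tokens, model_dim) is an assumed layout
outputs = gate(tokens)
l_aux, combine_weights, dispatch_mask = outputs[:3]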
- bagua.torch_api.model_parallel.moe.sharded_moe.gumbel_rsample(shape, device)¶
- Parameters:
shape (Tuple) –
device (torch.device) –
- Return type:
torch.Tensor
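A short sketch of drawing Gumbel noise with this helper, e.g. to perturb gate logits; the shape and its interpretation are illustrative assumptions.

import torch
from bagua.torch_api.model_parallel.moe.sharded_moe import gumbel_rsample

logits = torch.randn(32, 4)                      # (num_tokens, num_experts) assumed
noise = gumbel_rsample(logits.shape, device=logits.device)
noisy_logits = logits + noise                    # e.g. for noisy expert selection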
- bagua.torch_api.model_parallel.moe.sharded_moe.multiplicative_jitter(x, device, epsilon=0.01)¶
Modified from the Switch Transformer paper (Mesh Transformers). Multiplies values by a random number between 1 - epsilon and 1 + epsilon, which makes models more resilient to rounding errors introduced by bfloat16. This seems particularly important for logits.
- Parameters:
x (torch.Tensor) – the tensor to jitter
device (torch.device) – device on which the random factors are sampled
epsilon (float) – a floating point value controlling the jitter range
- Returns:
a jittered x.
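A minimal sketch of applying the jitter to a tensor; the shape is an arbitrary illustrative choice.

import torch
from bagua.torch_api.model_parallel.moe.sharded_moe import multiplicative_jitter

x = torch.randn(32, 512)
# Each value is scaled by a factor drawn uniformly from (1 - epsilon, 1 + epsilon).
jittered = multiplicative_jitter(x, device=x.device, epsilon=0.01)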
- bagua.torch_api.model_parallel.moe.sharded_moe.top1gating(logits, capacity_factor, min_capacity, used_token=None, noisy_gate_policy=None)¶
Implements Top1Gating on logits.
- Parameters:
logits (torch.Tensor) –
capacity_factor (float) –
min_capacity (int) –
used_token (torch.Tensor) –
noisy_gate_policy (Optional[str]) –
- Return type:
Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]
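A hedged sketch of calling top1gating on raw gate logits; the logits shape and the unpacking of only the first three elements of the documented 4-tuple are illustrative assumptions.

import torch
from bagua.torch_api.model_parallel.moe.sharded_moe import top1gating

logits = torch.randn(32, 4)                      # (num_tokens, num_experts) assumed
outputs = top1gating(logits, capacity_factor=1.0, min_capacity=4)
l_aux, combine_weights, dispatch_mask = outputs[:3]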
- bagua.torch_api.model_parallel.moe.sharded_moe.top2gating(logits, capacity_factor)¶
Implements Top2Gating on logits.
- Parameters:
logits (torch.Tensor) –
capacity_factor (float) –
- Return type:
Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]
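The same pattern applies to top2gating, which routes each token to its top two experts; again the shape and the partial unpacking are illustrative assumptions.

import torch
from bagua.torch_api.model_parallel.moe.sharded_moe import top2gating

logits = torch.randn(32, 4)                      # (num_tokens, num_experts) assumed
l_aux, combine_weights, dispatch_mask = top2gating(logits, capacity_factor=1.0)[:3]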
- bagua.torch_api.model_parallel.moe.sharded_moe.Base¶
- bagua.torch_api.model_parallel.moe.sharded_moe.exp_selection_uniform_map :Dict[torch.device, Callable]¶
- bagua.torch_api.model_parallel.moe.sharded_moe.gumbel_map :Dict[torch.device, Callable]¶
- bagua.torch_api.model_parallel.moe.sharded_moe.uniform_map :Dict[torch.device, Callable]¶