bagua.torch_api.model_parallel.moe.sharded_moe

Module Contents

class bagua.torch_api.model_parallel.moe.sharded_moe.MOELayer(gate, experts, num_local_experts, group=None)

Bases: Base

MOELayer module which implements MixtureOfExperts as described in GShard (https://arxiv.org/abs/2006.16668).

gate = TopKGate(model_dim, num_experts)
moe = MOELayer(gate, experts, num_local_experts)
output = moe(input)
l_aux = moe.l_aux
Parameters:
  • gate (torch.nn.Module) – gate network

  • experts (torch.nn.Module) – expert network

  • num_local_experts (int) – number of experts hosted on the local device

  • group (Optional[Any]) – process group used for expert communication

forward(*input, **kwargs)
Parameters:
  • input (torch.Tensor) –

  • kwargs (Any) –

Return type:

torch.Tensor
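
A minimal end-to-end sketch of wiring up an MOELayer follows. The FFNExpert class, the hyperparameter values, and the exact container type for experts are illustrative assumptions, not part of this API:

import torch

model_dim, num_experts, num_local_experts = 512, 8, 2

class FFNExpert(torch.nn.Module):
    # Hypothetical feed-forward expert network, for illustration only.
    def __init__(self, d):
        super().__init__()
        self.net = torch.nn.Sequential(
            torch.nn.Linear(d, 4 * d),
            torch.nn.ReLU(),
            torch.nn.Linear(4 * d, d),
        )

    def forward(self, x):
        return self.net(x)

gate = TopKGate(model_dim, num_experts)
experts = torch.nn.ModuleList([FFNExpert(model_dim) for _ in range(num_local_experts)])
moe = MOELayer(gate, experts, num_local_experts)

tokens = torch.randn(4, 16, model_dim)      # (batch, sequence, model_dim)
output = moe(tokens)                        # tokens routed through the experts
loss = output.sum() + moe.l_aux             # toy loss; add l_aux to the real training loss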

class bagua.torch_api.model_parallel.moe.sharded_moe.TopKGate(model_dim, num_experts, k=1, capacity_factor=1.0, eval_capacity_factor=1.0, min_capacity=4, noisy_gate_policy=None)

Bases: torch.nn.Module

Gate module which implements Top2Gating as described in GShard (https://arxiv.org/abs/2006.16668).

gate = TopKGate(model_dim, num_experts)
l_aux, combine_weights, dispatch_mask, exp_counts = gate(input)
Parameters:
  • model_dim (int) – size of model embedding dimension

  • num_experts (int) – number of experts in model

  • k (int) – number of top experts each token is routed to

  • capacity_factor (float) – capacity factor applied during training

  • eval_capacity_factor (float) – capacity factor applied during evaluation

  • min_capacity (int) – minimum number of token slots per expert

  • noisy_gate_policy (Optional[str]) – noisy gating policy (e.g. 'Jitter' or 'RSample'), or None to disable

wg : torch.nn.Linear
forward(input, used_token=None)
Parameters:
  • input (torch.Tensor) –

  • used_token (torch.Tensor) –

Return type:

Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]
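
A minimal sketch of calling the gate directly is shown below; the flattened (tokens, model_dim) input shape and the names of the unpacked outputs are assumptions based on the documented return type:

import torch

model_dim, num_experts, num_tokens = 512, 8, 64

gate = TopKGate(model_dim, num_experts, k=1, capacity_factor=1.0)
tokens = torch.randn(num_tokens, model_dim)   # flattened token embeddings

l_aux, combine_weights, dispatch_mask, exp_counts = gate(tokens)
# l_aux           – auxiliary load-balancing loss to add to the training loss
# combine_weights – per-token weights for recombining expert outputs
# dispatch_mask   – mask routing each token to an expert / capacity slot
# exp_counts      – number of tokens assigned to each expert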

bagua.torch_api.model_parallel.moe.sharded_moe.gumbel_rsample(shape, device)

Draws a reparameterized (rsample) Gumbel noise tensor of the given shape on the given device.
Parameters:
  • shape (Tuple) –

  • device (torch.device) –

Return type:

torch.Tensor
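
As a rough illustration, an equivalent reparameterized Gumbel sample can be drawn with torch.distributions; this is a sketch of the behavior suggested by the name, not the verbatim implementation:

import torch

def gumbel_rsample_sketch(shape, device):
    # Build a standard Gumbel(0, 1) distribution on the target device and
    # draw a differentiable (reparameterized) sample of the requested shape.
    zero = torch.tensor(0.0, device=device)
    one = torch.tensor(1.0, device=device)
    return torch.distributions.gumbel.Gumbel(zero, one).rsample(shape)

noise = gumbel_rsample_sketch((64, 8), torch.device("cpu"))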

bagua.torch_api.model_parallel.moe.sharded_moe.multiplicative_jitter(x, device, epsilon=0.01)

Multiplies values by a random number between 1 - epsilon and 1 + epsilon, making models more resilient to rounding errors introduced by bfloat16. This seems particularly important for logits. Modified from the Switch Transformer paper and the Mesh TensorFlow implementation.

Parameters:
  • x (torch.Tensor) – input tensor

  • device (torch.device) – device on which to sample the jitter

  • epsilon (float) – jitter magnitude

Returns:

a jittered x.
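
A minimal sketch of the operation described above, assuming uniform multiplicative noise; the helper name is hypothetical:

import torch

def multiplicative_jitter_sketch(x, device, epsilon=0.01):
    # Scale each element of x by a factor drawn from U(1 - epsilon, 1 + epsilon).
    low = torch.tensor(1.0 - epsilon, device=device)
    high = torch.tensor(1.0 + epsilon, device=device)
    noise = torch.distributions.uniform.Uniform(low=low, high=high).rsample(x.shape)
    return x * noise

x = torch.randn(16, 8)
jittered = multiplicative_jitter_sketch(x, x.device)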

bagua.torch_api.model_parallel.moe.sharded_moe.top1gating(logits, capacity_factor, min_capacity, used_token=None, noisy_gate_policy=None)

Implements Top1Gating on logits.

Parameters:
  • logits (torch.Tensor) –

  • capacity_factor (float) –

  • min_capacity (int) –

  • used_token (torch.Tensor) –

  • noisy_gate_policy (Optional[str]) –

Return type:

Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]
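
A minimal usage sketch follows; the (tokens, experts) logits layout and the GShard-style capacity rule in the comment are assumptions about this implementation:

import torch

num_tokens, num_experts = 64, 8
logits = torch.randn(num_tokens, num_experts)   # one row of gate scores per token

l_aux, combine_weights, dispatch_mask, exp_counts = top1gating(
    logits, capacity_factor=1.0, min_capacity=4, noisy_gate_policy="RSample"
)

# Assumed GShard-style capacity rule:
#   capacity = max(min_capacity, ceil(num_tokens / num_experts * capacity_factor))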

bagua.torch_api.model_parallel.moe.sharded_moe.top2gating(logits, capacity_factor)

Implements Top2Gating on logits.

Parameters:
  • logits (torch.Tensor) –

  • capacity_factor (float) –

Return type:

Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]
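
The sketch below shows how GShard-style dispatch and combine typically consume the gating outputs; the einsum index convention (s = token, e = expert, c = capacity slot, m = model dim) is an assumption about the tensor shapes:

import torch

num_tokens, num_experts, model_dim = 64, 8, 512
tokens = torch.randn(num_tokens, model_dim)
logits = torch.randn(num_tokens, num_experts)

l_aux, combine_weights, dispatch_mask, exp_counts = top2gating(logits, capacity_factor=1.0)

# Route tokens into per-expert capacity buffers ...
dispatched = torch.einsum("sec,sm->ecm", dispatch_mask.type_as(tokens), tokens)
expert_out = dispatched                       # placeholder for the expert outputs
# ... and weight the expert outputs back into per-token vectors.
combined = torch.einsum("sec,ecm->sm", combine_weights.type_as(tokens), expert_out)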

bagua.torch_api.model_parallel.moe.sharded_moe.Base
bagua.torch_api.model_parallel.moe.sharded_moe.exp_selection_uniform_map : Dict[torch.device, Callable]
bagua.torch_api.model_parallel.moe.sharded_moe.gumbel_map : Dict[torch.device, Callable]
bagua.torch_api.model_parallel.moe.sharded_moe.uniform_map : Dict[torch.device, Callable]
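
These module-level dictionaries appear to cache one sampler callable per device so the underlying distribution is constructed only once; the following is a sketch of that caching pattern, offered as an assumption about the implementation:

import torch
from typing import Callable, Dict

uniform_map_sketch: Dict[torch.device, Callable] = {}

def cached_uniform_sample(shape, device, epsilon=0.01):
    # Look up (or lazily build and cache) a per-device sampler, avoiding
    # re-constructing the distribution on every call.
    sampler = uniform_map_sketch.get(device)
    if sampler is None:
        low = torch.tensor(1.0 - epsilon, device=device)
        high = torch.tensor(1.0 + epsilon, device=device)
        sampler = torch.distributions.uniform.Uniform(low=low, high=high).rsample
        uniform_map_sketch[device] = sampler
    return sampler(shape)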