| |
|
class MixerBlock(layers.Layer):
    """One MLP-Mixer block: a token-mixing MLP applied across the sequence
    axis, followed by a channel-mixing MLP applied across the feature axis.
    Each sub-block uses a pre-LayerNorm, dropout on its output, and a
    residual connection.

    Expects inputs of shape (batch, seq_len, dim); returns the same shape.
    """

    def __init__(self, seq_len, dim, token_mlp_dim, channel_mlp_dim, dropout=0.0):
        """Build the two MLP branches.

        Args:
            seq_len: sequence length L of the inputs; also the output width
                of the token-mixing MLP (required so the residual add works).
            dim: feature dimension D of the inputs; also the output width
                of the channel-mixing MLP.
            token_mlp_dim: hidden width of the token-mixing MLP.
            channel_mlp_dim: hidden width of the channel-mixing MLP.
            dropout: dropout rate applied to each branch's output before the
                residual add (default 0.0 = no dropout).
        """
        super().__init__()
        self.seq_len = seq_len
        self.dim = dim
        self.token_mlp_dim = token_mlp_dim
        self.channel_mlp_dim = channel_mlp_dim

        # Token-mixing branch: LN -> Dense(token_mlp_dim, gelu) -> Dense(seq_len).
        # The Dense layers act across tokens because call() transposes first.
        self.ln1 = layers.LayerNormalization(epsilon=1e-6, dtype=tf.float32)
        self.token_fc1 = layers.Dense(token_mlp_dim, activation='gelu', dtype=tf.float32)
        self.token_fc2 = layers.Dense(seq_len, dtype=tf.float32)

        # Channel-mixing branch: LN -> Dense(channel_mlp_dim, gelu) -> Dense(dim),
        # acting on the last (feature) axis directly.
        self.ln2 = layers.LayerNormalization(epsilon=1e-6, dtype=tf.float32)
        self.channel_fc1 = layers.Dense(channel_mlp_dim, activation='gelu', dtype=tf.float32)
        self.channel_fc2 = layers.Dense(dim, dtype=tf.float32)

        # One Dropout layer shared by both branches; Dropout holds no
        # trainable state, so reuse is safe.
        self.dropout = layers.Dropout(dropout)

    def call(self, x, training=None):
        """Apply token mixing, then channel mixing, each with a residual add.

        Args:
            x: tensor of shape (batch, seq_len, dim).
            training: standard Keras training flag; forwarded to Dropout.

        Returns:
            Tensor of the same shape as ``x``.
        """
        # Token mixing: transpose to (batch, dim, seq_len) so the Dense
        # layers mix information across tokens, then transpose back.
        y = self.ln1(x)
        y = tf.transpose(y, perm=[0, 2, 1])
        y = self.token_fc1(y)
        y = self.token_fc2(y)
        y = tf.transpose(y, perm=[0, 2, 1])
        x = x + self.dropout(y, training=training)

        # Channel mixing: Dense layers operate on the last axis as-is.
        z = self.ln2(x)
        z = self.channel_fc1(z)
        z = self.channel_fc2(z)
        x = x + self.dropout(z, training=training)

        return x