  def __call__(
      self,
      batch,
      is_training,
      compute_loss=False,
      ensemble_representations=False,
      return_representations=False):
    """Run the AlphaFold model.

    Arguments:
      batch: Dictionary with inputs to the AlphaFold model.
      is_training: Whether the system is in training or inference mode.
      compute_loss: Whether to compute losses (requires extra features
        to be present in the batch and knowing the true structure).
      ensemble_representations: Whether to use ensembling of representations.
      return_representations: Whether to also return the intermediate
        representations.

    Returns:
      When compute_loss is True:
        a tuple of loss and output of AlphaFoldIteration.
      When compute_loss is False:
        just output of AlphaFoldIteration.

      The output of AlphaFoldIteration is a nested dictionary containing
        predictions from the various heads.
    """
    # SYMDESIGN - Extract previously predicted atom positions, if provided,
    # to use as the recycling initialization.
    prev_pos = batch.pop('prev_pos', None)

    impl = AlphaFoldIteration(self.config, self.global_config)
    batch_size, num_residues = batch['aatype'].shape
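
    # get_prev gathers the outputs that are recycled into the next iteration:
    # the final atom positions plus the first-row MSA and pair representations.
    # Gradients are stopped so recycling does not backpropagate across
    # iterations.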
    def get_prev(ret):
      new_prev = {
          'prev_pos':
              ret['structure_module']['final_atom_positions'],
          'prev_msa_first_row': ret['representations']['msa_first_row'],
          'prev_pair': ret['representations']['pair'],
      }
      return jax.tree_map(jax.lax.stop_gradient, new_prev)
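
    # do_call runs a single AlphaFoldIteration. When the MSA is resampled in
    # recycling, each recycling iteration sees its own slice of the ensembled
    # batch; otherwise every iteration reuses the full batch.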
    def do_call(prev,
                recycle_idx,
                compute_loss=compute_loss):
      if self.config.resample_msa_in_recycling:
        num_ensemble = batch_size // (self.config.num_recycle + 1)
        def slice_recycle_idx(x):
          start = recycle_idx * num_ensemble
          size = num_ensemble
          return jax.lax.dynamic_slice_in_dim(x, start, size, axis=0)
        ensembled_batch = jax.tree_map(slice_recycle_idx, batch)
      else:
        num_ensemble = batch_size
        ensembled_batch = batch

      non_ensembled_batch = jax.tree_map(lambda x: x, prev)

      return impl(
          ensembled_batch=ensembled_batch,
          non_ensembled_batch=non_ensembled_batch,
          is_training=is_training,
          compute_loss=compute_loss,
          ensemble_representations=ensemble_representations)
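
    # Initialize the recycled features. Atom positions can be seeded from an
    # externally supplied prev_pos (SYMDESIGN); the recycled representations
    # always start at zero.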
    prev = {}
    emb_config = self.config.embeddings_and_evoformer
    if emb_config.recycle_pos:
      # SYMDESIGN
      if prev_pos is None:
        prev['prev_pos'] = jnp.zeros(
            [num_residues, residue_constants.atom_type_num, 3])
      else:
        prev['prev_pos'] = prev_pos
      # SYMDESIGN
    if emb_config.recycle_features:
      prev['prev_msa_first_row'] = jnp.zeros(
          [num_residues, emb_config.msa_channel])
      prev['prev_pair'] = jnp.zeros(
          [num_residues, num_residues, emb_config.pair_channel])
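
    # Decide how many recycling iterations to run before the final pass.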
    if self.config.num_recycle:
      if 'num_iter_recycling' in batch:
        # Training time: num_iter_recycling is in batch.
        # The value for each ensemble batch is the same, so arbitrarily taking
        # 0-th.
        num_iter = batch['num_iter_recycling'][0]

        # Add insurance that we will not run more
        # recyclings than the model is configured to run.
        num_iter = jnp.minimum(num_iter, self.config.num_recycle)
      else:
        # Eval mode or tests: use the maximum number of iterations.
        num_iter = self.config.num_recycle
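
      # The loop body maps (iteration index, prev) to (index + 1, recycled
      # outputs of one forward pass); losses are never computed during these
      # intermediate recycling passes.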
      body = lambda x: (x[0] + 1,  # pylint: disable=g-long-lambda
                        get_prev(do_call(x[1], recycle_idx=x[0],
                                         compute_loss=False)))
      if hk.running_init():
        # When initializing the Haiku module, run one iteration of the
        # while_loop to initialize the Haiku modules used in `body`.
        _, prev = body((0, prev))
      else:
        _, prev = hk.while_loop(
            lambda x: x[0] < num_iter,
            body,
            (0, prev))
    else:
      num_iter = 0

    ret = do_call(prev=prev, recycle_idx=num_iter)
    if compute_loss:
      ret = ret[0], [ret[1]]

    if not return_representations:
      del (ret[0] if compute_loss else ret)['representations']  # pytype: disable=unsupported-operands
    return ret
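
  # A minimal usage sketch (illustration only, not part of the module). It
  # assumes the surrounding class is the standard AlphaFold Haiku module, a
  # model config from alphafold.model.config, trained parameters, and a
  # processed feature dict; `model_config`, `model_params`, and
  # `processed_features` are placeholder names for those.
  #
  #   def _forward(batch, is_training):
  #     model = AlphaFold(model_config.model)
  #     return model(batch,
  #                  is_training=is_training,
  #                  compute_loss=False,
  #                  ensemble_representations=True)
  #
  #   apply = hk.transform(_forward).apply
  #   predictions = apply(model_params, jax.random.PRNGKey(0),
  #                       processed_features, False)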