ml

proteinmpnn_factory module-attribute

proteinmpnn_factory: Annotated[ProteinMPNNFactory, 'Calling this factory method returns the single instance of the ProteinMPNN class specified by the "model_name" keyword argument'] = ProteinMPNNFactory()

Calling this factory method returns the single instance of the ProteinMPNN class specified by the "model_name" keyword argument

ProteinMPNNFactory

ProteinMPNNFactory(**kwargs)

Return a ProteinMPNN instance by calling the Factory instance with the ProteinMPNN model name

Handles creation and allotment to other processes, avoiding the expensive memory load of multiple instances by allocating a shared reference to the named ProteinMPNN model
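
A minimal usage sketch, assuming this module is importable as symdesign.resources.ml; repeated calls with the same arguments return the same shared instance:

from symdesign.resources.ml import proteinmpnn_factory

model = proteinmpnn_factory(model_name='v_48_020', backbone_noise=0.0)
same_model = proteinmpnn_factory(model_name='v_48_020', backbone_noise=0.0)
assert model is same_model  # One instance per model_name/backbone_noise key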

Source code in symdesign/resources/ml.py
def __init__(self, **kwargs):
    self._models = {}

__call__

__call__(model_name: str = 'v_48_020', backbone_noise: float = 0.0, ca_only: bool = False, **kwargs) -> ProteinMPNN

Return the specified ProteinMPNN object singleton

Parameters:

  • model_name (str, default: 'v_48_020' ) –

    The name of the model to use from ProteinMPNN taking the format v_X_Y, where X is neighbor distance and Y is noise

  • backbone_noise (float, default: 0.0 ) –

    The amount of backbone noise to add to the pose during design

  • ca_only (bool, default: False ) –

    Whether a minimal CA variant of the protein should be used for design calculations

Returns:

  • ProteinMPNN

    The instance of the initialized ProteinMPNN model

Source code in symdesign/resources/ml.py
def __call__(self, model_name: str = 'v_48_020', backbone_noise: float = 0., ca_only: bool = False, **kwargs) \
        -> ProteinMPNN:
    """Return the specified ProteinMPNN object singleton

    Args:
        model_name: The name of the model to use from ProteinMPNN taking the format v_X_Y,
            where X is neighbor distance and Y is noise
        backbone_noise: The amount of backbone noise to add to the pose during design
        ca_only: Whether a minimal CA variant of the protein should be used for design calculations
    Returns:
        The instance of the initialized ProteinMPNN model
    """
    if ca_only:
        ca = '_ca'
        if model_name == 'v_48_030':
            logger.error(f"No such ca_only model 'v_48_030'. Loading ca_only model 'v_48_020' (highest "
                         f"backbone noise ca_only model) instead")
            model_name = 'v_48_020'
        weights_dir = utils.path.protein_mpnn_ca_weights_dir
        required_memory = ca_model_memory
    else:
        ca = ''
        weights_dir = utils.path.protein_mpnn_weights_dir
        required_memory = vanilla_model_memory

    model_name_key = f'{model_name}{ca}_{backbone_noise}'
    model = self._models.get(model_name_key)
    if model:
        return model
    else:  # Create a new ProteinMPNN model instance
        # if not self._models:  # Nothing initialized
        # Acquire an adequate computing device
        if torch.cuda.is_available():
            max_memory = required_memory
            for device_int in range(torch.cuda.device_count()):
                available_memory = get_device_memory(torch.device(device_int), free=True)
                if available_memory > max_memory:
                    max_memory = available_memory
                    device_id = device_int
            try:
                device: torch.device = torch.device(device_id)
            except UnboundLocalError:  # No device has memory greater than ProteinMPNN minimum required
                device = torch.device('cpu')
            else:
                # Set the environment to use memory efficient cuda management
                max_split = 1000
                pytorch_conf = f'max_split_size_mb:{max_split},' \
                               f'roundup_power2_divisions:4,' \
                               f'garbage_collection_threshold:0.7'
                os.environ['PYTORCH_CUDA_ALLOC_CONF'] = pytorch_conf
                logger.debug(f'Setting pytorch configuration:\n{pytorch_conf}\n'
                             f'Result:{os.getenv("PYTORCH_CUDA_ALLOC_CONF")}')
        else:
            device = torch.device('cpu')

        checkpoint = torch.load(os.path.join(weights_dir, f'{model_name}.pt'), map_location=device)
        hidden_dim = 128
        num_layers = 3
        with torch.no_grad():
            model = _ProteinMPNN(num_letters=mpnn_alphabet_length,
                                 node_features=hidden_dim,
                                 edge_features=hidden_dim,
                                 hidden_dim=hidden_dim,
                                 num_encoder_layers=num_layers,
                                 num_decoder_layers=num_layers,
                                 augment_eps=backbone_noise,
                                 k_neighbors=checkpoint['num_edges'],
                                 ca_only=ca_only)
            model.to(device)
            model.load_state_dict(checkpoint['model_state_dict'])
            model.eval()
            model.device = device
            model.model_name = model_name_key

        model.log.info(f"ProteinMPNN model '{model_name_key}' on device '{device}' has "
                       f'{checkpoint["num_edges"]} edges and {checkpoint["noise_level"]} Angstroms of training '
                       'noise')
        # number_of_mpnn_model_parameters = sum([math.prod(param.size()) for param in model.parameters()])
        # logger.debug(f'The number of proteinmpnn model parameters is: {number_of_mpnn_model_parameters}')

        self._models[model_name_key] = model

    return model

get

get(**kwargs) -> ProteinMPNN

Return the specified ProteinMPNN object singleton

Returns:

  • ProteinMPNN

    The instance of the initialized ProteinMPNN model

Source code in symdesign/resources/ml.py
def get(self, **kwargs) -> ProteinMPNN:
    """Return the specified ProteinMPNN object singleton

    Keyword Args:
        model_name - str = 'v_48_020' - The name of the model to use from ProteinMPNN.
            v_X_Y where X is neighbor distance, and Y is noise
        backbone_noise - float = 0.0 - The amount of backbone noise to add to the pose during design

    Returns:
        The instance of the initialized ProteinMPNN model
    """
    return self.__call__(**kwargs)

RunModel

RunModel(config: ConfigDict, params: Optional[Mapping[str, Mapping[str, ndarray]]] = None, device: Device = None)

Bases: RunModel

Container for JAX model.

Parameters:

  • config (ConfigDict) –

  • params (Optional[Mapping[str, Mapping[str, ndarray]]], default: None ) –

  • device (Device, default: None ) –

    The device the model should be compiled on

Source code in symdesign/resources/ml.py
def __init__(self,
             config: ml_collections.ConfigDict,
             params: Optional[Mapping[str, Mapping[str, jnp.ndarray]]] = None,
             # SYMDESIGN
             device: jax_xla.Device = None):
  """

  Args:
      config:
      params:
      device: The device the model should be compiled on
  """
  # SYMDESIGN
  self.config = config
  self.params = params
  self.multimer_mode = config.model.global_config.multimer_mode
  # SYMDESIGN
  self.parameter_map = {}
  # SYMDESIGN

  if self.multimer_mode:
    def _forward_fn(batch):
      # SYMDESIGN
      model = multimer.AlphaFoldInitialGuess(self.config.model)
      # SYMDESIGN
      return model(
          batch,
          is_training=False)
  else:
    def _forward_fn(batch):
      # SYMDESIGN
      model = monomer.AlphaFoldInitialGuess(self.config.model)
      # SYMDESIGN
      return model(
          batch,
          is_training=False,
          compute_loss=False,
          ensemble_representations=True)

  self.apply = jax.jit(hk.transform(_forward_fn).apply, device=device)
  self.init = jax.jit(hk.transform(_forward_fn).init, device=device)

predict

predict(feat: FeatureDict, random_seed: int) -> Mapping[str, Any]

Makes a prediction by inferencing the model on the provided features.

Parameters:

  • feat (FeatureDict) –

    A dictionary of NumPy feature arrays as output by RunModel.process_features.

  • random_seed (int) –

    The random seed to use when running the model. In the multimer model this controls the MSA sampling.

Returns:

  • Mapping[str, Any]

    A dictionary of model outputs.

Source code in symdesign/resources/ml.py
def predict(self,
            feat: features.FeatureDict,
            random_seed: int,
            ) -> Mapping[str, Any]:
  """Makes a prediction by inferencing the model on the provided features.

  Args:
    feat: A dictionary of NumPy feature arrays as output by
      RunModel.process_features.
    random_seed: The random seed to use when running the model. In the
      multimer model this controls the MSA sampling.

  Returns:
    A dictionary of model outputs.
  """
  self.init_params(feat)
  logging.info('Running predict with shape(feat) = %s',
               tree.map_structure(lambda x: x.shape, feat))
  result = self.apply(self.params, jax.random.PRNGKey(random_seed), feat)

  # This block is to ensure benchmark timings are accurate. Some blocking is
  # already happening when computing get_confidence_metrics, and this ensures
  # all outputs are blocked on.
  jax.tree_map(lambda x: x.block_until_ready(), result)
  # SYMDESIGN
  result.update(
      afmodel.get_confidence_metrics(result, multimer_mode=self.multimer_mode))
  # SYMDESIGN
  logging.info('Output shape was %s',
               tree.map_structure(lambda x: x.shape, result))
  return result

set_params

set_params(model_params: dict[str, Mapping[str, Mapping[str, ndarray]]])

Set a collection of parameter sets that a single compiled model can run with

Parameters:

  • model_params (dict[str, Mapping[str, Mapping[str, ndarray]]]) –

    A dictionary of model parameters

Returns: None

Source code in symdesign/resources/ml.py
def set_params(self, model_params: dict[str, Mapping[str, Mapping[str, jnp.ndarray]]]):
    """Set a collection of parameters that a single compiled model should run

    Args:
        model_params: A dictionary of model parameters
    Returns:
        None
    """
    self.parameter_map = model_params

predict_with_params

predict_with_params(parameter_type: str, feat: FeatureDict, random_seed: int) -> Mapping[str, Any]

Makes a prediction by inferencing the model on the provided features.

Parameters:

  • parameter_type (str) –

    The name of the parameter set to fetch

  • feat (FeatureDict) –

    A dictionary of NumPy feature arrays as output by RunModel.process_features.

  • random_seed (int) –

    The random seed to use when running the model. In the multimer model this controls the MSA sampling.

Returns:

  • Mapping[str, Any]

    A dictionary of model outputs.

Source code in symdesign/resources/ml.py
def predict_with_params(self, parameter_type: str,
                        feat: features.FeatureDict,
                        random_seed: int,
                        ) -> Mapping[str, Any]:
    """Makes a prediction by inferencing the model on the provided features.

    Args:
        parameter_type: The name of the parameter set to fetch
        feat: A dictionary of NumPy feature arrays as output by
            RunModel.process_features.
        random_seed: The random seed to use when running the model. In the
            multimer model this controls the MSA sampling.

    Returns:
        A dictionary of model outputs.
    """
    logging.info('Running predict with shape(feat) = %s',
                 tree.map_structure(lambda x: x.shape, feat))
    try:
        params = self.parameter_map[parameter_type]
    except KeyError:
        raise KeyError(f"The parameter_type='{parameter_type}' isn't available from the viable parameter "
                       f"sets\nCurrently available types include: {', '.join(self.parameter_map.keys())}")
    result = self.apply(params, jax.random.PRNGKey(random_seed), feat)

    # This block is to ensure benchmark timings are accurate. Some blocking is
    # already happening when computing get_confidence_metrics, and this ensures
    # all outputs are blocked on.
    jax.tree_map(lambda x: x.block_until_ready(), result)
    result.update(
        afmodel.get_confidence_metrics(result, multimer_mode=self.multimer_mode))
    logging.info('Output shape was %s',
                 tree.map_structure(lambda x: x.shape, result))
    return result
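
A hedged sketch of the set_params/predict_with_params workflow; runner, model_1_params, model_2_params, and feat are hypothetical placeholders (a compiled RunModel instance, AlphaFold parameter mappings already loaded from disk, and features from RunModel.process_features):

# Hypothetical placeholders: runner, model_1_params, model_2_params, feat
runner.set_params({'model_1': model_1_params, 'model_2': model_2_params})
# Reuse the single compiled model with each parameter set
for parameter_type in runner.parameter_map:
    result = runner.predict_with_params(parameter_type, feat, random_seed=0)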

get_device_memory

get_device_memory(device: device | int | str | None, free: bool = False) -> int

Get the memory available for a requested device to calculate computational constraints

Parameters:

  • device (device | int | str | None) –

    The current device of the pytorch model in question

  • free (bool, default: False ) –

    Whether to return the free memory if the device is a cuda GPU, otherwise return all pytorch memory

Returns:

  • int

    The bytes of memory available

Source code in symdesign/resources/ml.py
def get_device_memory(device: torch.device | int | str | None, free: bool = False) -> int:
    """Get the memory available for a requested device to calculate computational constraints

    Args:
        device: The current device of the pytorch model in question
        free: Whether to return the free memory if the device is a cuda GPU, otherwise return all pytorch memory
    Returns:
        The bytes of memory available
    """
    if not isinstance(device, torch.device):
        device = torch.device(device)

    if device.type == 'cpu':  # device is None or
        memory_constraint = utils.get_available_memory()
        logger.debug(f'The available cpu memory is: {memory_constraint}')
    else:
        free_memory, gpu_memory_total = torch.cuda.mem_get_info(device)
        logger.debug(f'The available gpu memory is: {free_memory}')
        memory_reserved = torch.cuda.memory_reserved(device)
        logger.debug(f'The reserved gpu memory is: {memory_reserved}')
        if free:
            memory_constraint = free_memory
        else:
            memory_constraint = free_memory + memory_reserved

    return memory_constraint
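
A minimal sketch querying the free memory on the first CUDA device, falling back to the CPU:

import torch
from symdesign.resources.ml import get_device_memory

device = torch.device(0) if torch.cuda.is_available() else torch.device('cpu')
available_bytes = get_device_memory(device, free=True)
print(f'{available_bytes / 1e9:.2f} GB available on {device}')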

calculate_proteinmpnn_batch_length

calculate_proteinmpnn_batch_length(model: ProteinMPNN, number_of_residues: int, element_memory: int = 4) -> int

Calculate how many ProteinMPNN inputs can be batched in the memory available on the model's device

Parameters:

  • model (ProteinMPNN) –

    The ProteinMPNN model

  • number_of_residues (int) –

    The number of residues used in the ProteinMPNN model

  • element_memory (int, default: 4 ) –

    Where each element is np.int64, np.float32, etc.

Returns:

  • int

    The size of the batch that can be completed for the ProteinMPNN model given its device

Source code in symdesign/resources/ml.py
def calculate_proteinmpnn_batch_length(model: ProteinMPNN, number_of_residues: int, element_memory: int = 4) -> int:
    """

    Args:
        model: The ProteinMPNN model
        number_of_residues: The number of residues used in the ProteinMPNN model
        element_memory: Where each element is np.int64, np.float32, etc.
    Returns:
        The size of the batch that can be completed for the ProteinMPNN model given its device
    """
    memory_constraint = get_device_memory(model.device)

    number_of_elements_available = memory_constraint // element_memory
    logger.debug(f'The number_of_elements_available is: {number_of_elements_available}')
    number_of_model_parameter_elements = sum([math.prod(param.size()) for param in model.parameters()])
    logger.debug(f'The number_of_model_parameter_elements is: {number_of_model_parameter_elements}')
    model_elements = number_of_model_parameter_elements
    # Todo: use 5 because an ideal CB is added by the model later when ca_only = False
    num_model_residues = 5
    model_elements += math.prod((number_of_residues, num_model_residues, 3))  # X,
    model_elements += number_of_residues  # S.shape
    model_elements += number_of_residues  # chain_mask.shape
    model_elements += number_of_residues  # chain_encoding.shape
    model_elements += number_of_residues  # residue_idx.shape
    model_elements += number_of_residues  # mask.shape
    model_elements += number_of_residues  # residue_mask.shape
    model_elements += math.prod((number_of_residues, 21))  # omit_AA_mask.shape
    model_elements += number_of_residues  # pssm_coef.shape
    model_elements += math.prod((number_of_residues, 20))  # pssm_bias.shape
    model_elements += math.prod((number_of_residues, 20))  # pssm_log_odds_mask.shape
    model_elements += number_of_residues  # tied_beta.shape
    model_elements += math.prod((number_of_residues, 21))  # bias_by_res.shape
    logger.debug(f'The number of model_elements is: {model_elements}')

    number_of_batches = number_of_elements_available // model_elements
    batch_length = number_of_batches // proteinmpnn_batch_divisor
    if batch_length == 0:
        not_enough_proteinmpnn_memory = f"Can't find a device for {model} with enough memory to complete a single " \
                                        f"batch of work with {number_of_residues} residues in the model"
        if model.device.type == 'cpu':
            raise RuntimeError(not_enough_proteinmpnn_memory)

        old_device = model.device
        # This won't work. Try to put the model on a new device
        max_memory = vanilla_model_memory
        for device_int in range(torch.cuda.device_count()):
            available_memory = get_device_memory(torch.device(device_int), free=True)
            if available_memory > max_memory:
                max_memory = available_memory
                device_id = device_int
        try:
            device: torch.device = torch.device(device_id)
        except UnboundLocalError:  # No device has memory greater than ProteinMPNN minimum required
            device = torch.device('cpu')

        if device == old_device:
            # The same device was selected. Solving with the gpu is stuck
            if device.type == 'cpu':
                # The device hasn't changed and is the cpu. There is no larger memory option
                raise RuntimeError(not_enough_proteinmpnn_memory)
            else:
                # Try one more time ensuring cpu. This will be caught above if still not enough memory
                device = torch.device('cpu')

        # Set the device parameters
        model.to(device)
        model.device = device
        # Recurse
        return calculate_proteinmpnn_batch_length(model, number_of_residues, element_memory)

    return batch_length
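
A short sketch sizing a batch for a hypothetical 200-residue model using the factory singleton documented above, assuming the ProteinMPNN weights are available locally:

from symdesign.resources.ml import calculate_proteinmpnn_batch_length, proteinmpnn_factory

model = proteinmpnn_factory()  # Default 'v_48_020' model
batch_length = calculate_proteinmpnn_batch_length(model, number_of_residues=200)
print(f'Running ProteinMPNN in batches of {batch_length}')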

batch_calculation

batch_calculation(size: int, batch_length: int, setup: Callable = None, compute_failure_exceptions: tuple[Type[Exception], ...] = (Exception,)) -> Callable

Use as a decorator to execute a function in batches over an input that is too large for available computational resources, typically memory

Produces the variables actual_batch_length and batch_slice that can be used inside the decorated function

Parameters:

  • size (int) –

    The total number of units of work to be done

  • batch_length (int) –

    The starting length of a batch. This should be chosen empirically

  • setup (Callable, default: None ) –

    A Callable which should be called before the batches are executed to produce data that is passed to the function. The first argument of this Callable should be batch_length

  • compute_failure_exceptions (tuple[Type[Exception], ...], default: (Exception,) ) –

    A tuple of exceptions that, when raised, allow the calculation to restart with a smaller batch

Decorated Callable Args:

  • args – The arguments to pass to the function

  • kwargs – Keyword Arguments to pass to the decorated Callable

  • setup_args – Arguments to pass to the setup Callable

  • setup_kwargs – Keyword Arguments to pass to the setup Callable

  • return_containers (dict) – The key and SupportsIndex value to store decorated Callable returns inside

Returns:

  • Callable

    The decorator. When called, the decorated Callable returns the populated return_containers

Source code in symdesign/resources/ml.py
def batch_calculation(size: int, batch_length: int, setup: Callable = None,
                      compute_failure_exceptions: tuple[Type[Exception], ...] = (Exception,)) -> Callable:
    """Use as a decorator to execute a function in batches over an input that is too large for available computational
    resources, typically memory

    Produces the variables actual_batch_length and batch_slice that can be used inside the decorated function

    Args:
        size: The total number of units of work to be done
        batch_length: The starting length of a batch. This should be chosen empirically
        setup: A Callable which should be called before the batches are executed to produce data that is passed to the
            function. The first argument of this Callable should be batch_length
        compute_failure_exceptions: A tuple of exceptions that, when raised, allow the calculation to restart
            with a smaller batch
    Decorated Callable Args:
        args: The arguments to pass to the function
        kwargs: Keyword Arguments to pass to the decorated Callable
        setup_args: Arguments to pass to the setup Callable
        setup_kwargs: Keyword Arguments to pass to the setup Callable
        return_containers: dict - The key and SupportsIndex value to store decorated Callable returns inside
    Returns:
        The populated function_return_containers
    """
    def wrapper(func: Callable) -> Callable[[tuple[Any, ...], dict | None, tuple, dict | None, dict[str, Any]], dict]:
        if setup is None:
            def setup_(*_args, **_kwargs) -> dict:
                return {}
        else:
            setup_ = setup

        @functools.wraps(func)
        def wrapped(*args, return_containers: dict = None,
                    setup_args: tuple = tuple(), setup_kwargs: dict = None, **kwargs) -> dict:

            if return_containers is None:
                return_containers = {}

            if setup_kwargs is None:
                setup_kwargs = {}

            _batch_length = batch_length
            # finished = False
            _error = last_error = None
            while True:  # not finished:
                logger.debug(f'The batch_length is: {_batch_length}')
                try:  # The next batch_length
                    # The number_of_batches indicates how many iterations are needed to exhaust all models
                    try:
                        number_of_batches = int(ceil(size/_batch_length) or 1)  # Select at least 1
                    except ZeroDivisionError:  # We hit the minimal batch size. Report the previous error
                        if last_error is not None:  # This exited from the compute_failure_exceptions except
                            break  # break out and raise the _error
                        else:
                            raise ValueError(
                                f'The batch_length ({batch_length}) must be greater than 0')
                    # Perform any setup operations
                    # logger.critical(f'Before SETUP\nmemory_allocated: {torch.cuda.memory_allocated()}'
                    #                 f'\nmemory_reserved: {torch.cuda.memory_reserved()}')
                    setup_start = time.time()
                    setup_returns = setup_(_batch_length, *setup_args, **setup_kwargs)
                    logger.debug(f'{batch_calculation.__name__} setup function took {time.time() - setup_start:8f}s')
                    # logger.critical(f'After SETUP\nmemory_allocated: {torch.cuda.memory_allocated()}'
                    #                 f'\nmemory_reserved: {torch.cuda.memory_reserved()}')
                    batch_start = time.time()
                    for batch in range(number_of_batches):
                        # Find the upper slice limit
                        batch_slice = slice(batch * _batch_length, min((batch+1) * _batch_length, size))
                        # Perform the function, batch_slice must be used inside the func
                        logger.debug(f'Calculating batch {batch + 1}')
                        function_returns = func(batch_slice, *args, **kwargs, **setup_returns)
                        # Set the returned values in the order they were received to the precalculated return_container
                        for return_key, return_value in list(function_returns.items()):
                            try:  # To access the return_container_key in the function
                                return_containers[return_key][batch_slice] = return_value
                            except KeyError:  # If it doesn't exist
                                raise KeyError(
                                    f"Couldn't return the data specified by {return_key} to the return_container with "
                                    f"keys:{', '.join(return_containers.keys())}")
                            except ValueError as error:  # Arrays are incorrectly sized
                                raise ValueError(
                                    f"Couldn't return the data specified by {return_key} from {func.__name__} due to: "
                                    f"{error}")
                        # for return_container_key, return_container in list(return_containers.items()):
                        #     try:  # To access the return_container_key in the function
                        #         return_container[batch_slice] = function_returns[return_container_key]
                        #     except KeyError:  # If it doesn't exist
                        #         # Remove the data from the return_containers
                        #         return_containers.pop(return_container_key)

                    # Report success
                    logger.debug(f'Successful execution with batch_length of {_batch_length}. '
                                 f'Took {time.time() - batch_start:8f}s')
                    last_error = None
                    break  # finished = True
                except compute_failure_exceptions as error:
                    # del setup_returns
                    # logger.critical(f'After ERROR\nmemory_allocated: {torch.cuda.memory_allocated()}'
                    #                 f'\nmemory_reserved: {torch.cuda.memory_reserved()}')
                    # gc.collect()
                    # logger.critical(f'After GC\nmemory_allocated: {torch.cuda.memory_allocated()}'
                    #                 f'\nmemory_reserved: {torch.cuda.memory_reserved()}')
                    if _error is None:  # Set the error the first time
                        # _error = last_error = error
                        _error = last_error = traceback.format_exc()  # .format_exception(error)
                    else:
                        # last_error = error
                        last_error = traceback.format_exc()  # .format_exception(error)
                    _batch_length -= 1

            if last_error is not None:  # This exited from the ZeroDivisionError except
                # try:
                logger.critical(f'{batch_calculation.__name__} exited with the following exceptions:\n\nThe first '
                                f'exception in the traceback was the result of the first iteration, while the '
                                f'most recent exception in the traceback is last\n')
                # raise _error
                print(''.join(_error))
                # except compute_failure_exceptions:
                #     raise last_error
                print(''.join(last_error))
                raise utils.SymDesignException(
                    f"{func.__name__}() wasn't able to be executed. See the above traceback")

            return return_containers
        return wrapped
    return wrapper
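
A minimal sketch of decorating a function with batch_calculation; the work size, doubling function, and preallocated return container are all hypothetical:

import numpy as np
from symdesign.resources.ml import batch_calculation

size = 1000

@batch_calculation(size=size, batch_length=100, compute_failure_exceptions=(RuntimeError,))
def double(batch_slice: slice, data: np.ndarray) -> dict[str, np.ndarray]:
    # The decorator supplies batch_slice for each successive batch
    return {'doubled': data[batch_slice] * 2}

results = double(np.arange(size), return_containers={'doubled': np.empty(size)})
assert (results['doubled'] == np.arange(size) * 2).all()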

create_decoding_order

create_decoding_order(randn: Tensor, chain_mask: Tensor, tied_pos: Iterable[Container] = None, to_device: str = None, **kwargs) -> Tensor

Create the residue decoding order for ProteinMPNN sampling

Parameters:

  • randn (Tensor) –

    A random noise tensor used to set the decoding order

  • chain_mask (Tensor) –

    A mask where positions with a value of 0 are decoded first

  • tied_pos (Iterable[Container], default: None ) –

    Groups of residue indices that should be decoded together

  • to_device (str, default: None ) –

    The device to create the decoding order on

Returns:

  • Tensor

    The decoding order

Source code in symdesign/resources/ml.py
def create_decoding_order(randn: torch.Tensor, chain_mask: torch.Tensor, tied_pos: Iterable[Container] = None,
                          to_device: str = None, **kwargs) -> torch.Tensor:
    """

    Args:
        randn:
        chain_mask:
        tied_pos:
        to_device:

    Returns:

    """
    if to_device is None:
        to_device = randn.device
    # Numbers are smaller for places where chain_mask = 0.0 and higher for places where chain_mask = 1.0
    decoding_order = torch.argsort((chain_mask+0.0001) * (torch.abs(randn)))

    if tied_pos is not None:
        # Calculate the tied decoding order according to ProteinMPNN.tied_sample()
        new_decoding_order: list[list[int]] = []
        found_decoding_indices = []
        for t_dec in list(decoding_order[0].cpu().numpy()):
            if t_dec not in found_decoding_indices:
                for item in tied_pos:
                    if t_dec in item:
                        break
                else:
                    item = [t_dec]
                # Keep list of lists format
                new_decoding_order.append(item)
                # Add all found decoding_indices
                found_decoding_indices.extend(item)

        decoding_order = torch.tensor(found_decoding_indices, device=to_device)[None].repeat(len(randn), 1)

    return decoding_order
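
A small sketch tying two positions so they decode together; the tensor sizes are hypothetical:

import torch
from symdesign.resources.ml import create_decoding_order

randn = torch.randn(1, 10)
chain_mask = torch.ones(1, 10)
order = create_decoding_order(randn, chain_mask, tied_pos=[[0, 5]])
# order has shape (1, 10); position 5 immediately follows position 0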

batch_proteinmpnn_input

batch_proteinmpnn_input(size: int = None, **kwargs) -> dict[str, ndarray]

Set up all data for batches of proteinmpnn design

Parameters:

  • size (int, default: None ) –

    The number of inputs to use. If left blank, the size will be inferred from axis=0 of the X array

Other Parameters:

  • X

    numpy.ndarray = None - The array specifying the parameter X

  • X_unbound

    numpy.ndarray = None - The array specifying the parameter X_unbound

  • S

    numpy.ndarray = None - The array specifying the parameter S

  • randn

    numpy.ndarray = None - The array specifying the parameter randn

  • chain_mask

    numpy.ndarray = None - The array specifying the parameter chain_mask

  • chain_encoding

    numpy.ndarray = None - The array specifying the parameter chain_encoding

  • residue_idx

    numpy.ndarray = None - The array specifying the parameter residue_idx

  • mask

    numpy.ndarray = None - The array specifying the parameter mask

  • chain_M_pos

    numpy.ndarray = None - The array specifying the parameter chain_M_pos (residue_mask)

  • omit_AA_mask

    numpy.ndarray = None - The array specifying the parameter omit_AA_mask

  • pssm_coef

    numpy.ndarray = None - The array specifying the parameter pssm_coef

  • pssm_bias

    numpy.ndarray = None - The array specifying the parameter pssm_bias

  • pssm_log_odds_mask

    numpy.ndarray = None - The array specifying the parameter pssm_log_odds_mask

  • bias_by_res

    numpy.ndarray = None - The array specifying the parameter bias_by_res

  • tied_beta

    numpy.ndarray = None - The array specifying the parameter tied_beta

Returns:

  • dict[str, ndarray]

    A dictionary with each of the ProteinMPNN parameters formatted in a batch

Source code in symdesign/resources/ml.py
def batch_proteinmpnn_input(size: int = None, **kwargs) -> dict[str, np.ndarray]:
    """Set up all data for batches of proteinmpnn design

    Args:
        size: The number of inputs to use. If left blank, the size will be inferred from axis=0 of the X array

    Keyword Args:
        X: numpy.ndarray = None - The array specifying the parameter X
        X_unbound: numpy.ndarray = None - The array specifying the parameter X_unbound
        S: numpy.ndarray = None - The array specifying the parameter S
        randn: numpy.ndarray = None - The array specifying the parameter randn
        chain_mask: numpy.ndarray = None - The array specifying the parameter chain_mask
        chain_encoding: numpy.ndarray = None - The array specifying the parameter chain_encoding
        residue_idx: numpy.ndarray = None - The array specifying the parameter residue_idx
        mask: numpy.ndarray = None - The array specifying the parameter mask
        chain_M_pos: numpy.ndarray = None - The array specifying the parameter chain_M_pos (residue_mask)
        omit_AA_mask: numpy.ndarray = None - The array specifying the parameter omit_AA_mask
        pssm_coef: numpy.ndarray = None - The array specifying the parameter pssm_coef
        pssm_bias: numpy.ndarray = None - The array specifying the parameter pssm_bias
        pssm_log_odds_mask: numpy.ndarray = None - The array specifying the parameter pssm_log_odds_mask
        bias_by_res: numpy.ndarray = None - The array specifying the parameter bias_by_res
        tied_beta: numpy.ndarray = None - The array specifying the parameter tied_beta

    Returns:
        A dictionary with each of the ProteinMPNN parameters formatted in a batch
    """
    # This is my preferred name for the chain_M_pos...
    # residue_mask: (numpy.ndarray) = None - The array specifying the parameter residue_mask of ProteinMPNN
    if size is None:  # Use X as is
        X = kwargs.get('X')
        if X is None:
            raise ValueError(
                f"{batch_proteinmpnn_input.__name__} must pass keyword argument 'X' if argument 'size' is None")
        size = len(X)
    # else:
    #     X = np.tile(X, (size,) + (1,)*X.ndim)

    # Stack ProteinMPNN sequence design task in "batches"
    device_kwargs = {}
    for key in batch_params:
        param = kwargs.pop(key, None)
        if param is not None:
            device_kwargs[key] = np.tile(param, (size,) + (1,)*param.ndim)

    # Add all kwargs that were not accessed back to the return dictionary
    device_kwargs.update(**kwargs)
    return device_kwargs
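
A minimal sketch tiling single-pose arrays into a batch of four (the array shapes are hypothetical):

import numpy as np
from symdesign.resources.ml import batch_proteinmpnn_input

X = np.random.rand(100, 4, 3).astype(np.float32)  # (residues, atoms, xyz)
S = np.zeros(100, dtype=np.int64)
batch = batch_proteinmpnn_input(size=4, X=X, S=S)
# batch['X'].shape == (4, 100, 4, 3) and batch['S'].shape == (4, 100)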

proteinmpnn_to_device

proteinmpnn_to_device(device: str = None, **kwargs) -> dict[str, Tensor]

Set up all data to torch.Tensors for ProteinMPNN design

Parameters:

  • device (str, default: None ) –

    The device to load tensors to

Other Parameters:

  • X

    numpy.ndarray = None - The array specifying the parameter X

  • X_unbound

    numpy.ndarray = None - The array specifying the parameter X_unbound

  • S

    numpy.ndarray = None - The array specifying the parameter S

  • randn

    numpy.ndarray = None - The array specifying the parameter randn

  • chain_mask

    numpy.ndarray = None - The array specifying the parameter chain_mask

  • chain_encoding

    numpy.ndarray = None - The array specifying the parameter chain_encoding

  • residue_idx

    numpy.ndarray = None - The array specifying the parameter residue_idx

  • mask

    numpy.ndarray = None - The array specifying the parameter mask

  • chain_M_pos

    numpy.ndarray = None - The array specifying the parameter chain_M_pos (residue_mask)

  • omit_AA_mask

    numpy.ndarray = None - The array specifying the parameter omit_AA_mask

  • pssm_coef

    numpy.ndarray = None - The array specifying the parameter pssm_coef

  • pssm_bias

    numpy.ndarray = None - The array specifying the parameter pssm_bias

  • pssm_log_odds_mask

    numpy.ndarray = None - The array specifying the parameter pssm_log_odds_mask

  • bias_by_res

    numpy.ndarray = None - The array specifying the parameter bias_by_res

  • tied_beta

    numpy.ndarray = None - The array specifying the parameter tied_beta

Returns:

  • dict[str, Tensor]

    The torch.Tensor ProteinMPNN parameters

Source code in symdesign/resources/ml.py
def proteinmpnn_to_device(device: str = None, **kwargs) -> dict[str, torch.Tensor]:
    """Set up all data to torch.Tensors for ProteinMPNN design

    Args:
        device: The device to load tensors to

    Keyword Args:
        X: numpy.ndarray = None - The array specifying the parameter X
        X_unbound: numpy.ndarray = None - The array specifying the parameter X_unbound
        S: numpy.ndarray = None - The array specifying the parameter S
        randn: numpy.ndarray = None - The array specifying the parameter randn
        chain_mask: numpy.ndarray = None - The array specifying the parameter chain_mask
        chain_encoding: numpy.ndarray = None - The array specifying the parameter chain_encoding
        residue_idx: numpy.ndarray = None - The array specifying the parameter residue_idx
        mask: numpy.ndarray = None - The array specifying the parameter mask
        chain_M_pos: numpy.ndarray = None - The array specifying the parameter chain_M_pos (residue_mask)
        omit_AA_mask: numpy.ndarray = None - The array specifying the parameter omit_AA_mask
        pssm_coef: numpy.ndarray = None - The array specifying the parameter pssm_coef
        pssm_bias: numpy.ndarray = None - The array specifying the parameter pssm_bias
        pssm_log_odds_mask: numpy.ndarray = None - The array specifying the parameter pssm_log_odds_mask
        bias_by_res: numpy.ndarray = None - The array specifying the parameter bias_by_res
        tied_beta: numpy.ndarray = None - The array specifying the parameter tied_beta

    Returns:
        The torch.Tensor ProteinMPNN parameters
    """
    if device is None:
        raise ValueError('Must provide the desired device to load proteinmpnn')
    logger.debug(f'Loading ProteinMPNN parameters to device: {device}')

    # Convert all numpy arrays to pytorch
    device_kwargs = {}
    for key, dtype in dtype_map.items():
        param = kwargs.pop(key, None)
        if param is not None:
            device_kwargs[key] = torch.from_numpy(param).to(dtype=dtype, device=device)

    # Add all kwargs that were not accessed back to the return dictionary
    device_kwargs.update(**kwargs)
    return device_kwargs
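
A minimal sketch converting the batched arrays from batch_proteinmpnn_input() to torch.Tensor instances on the CPU:

import numpy as np
import torch
from symdesign.resources.ml import batch_proteinmpnn_input, proteinmpnn_to_device

batch = batch_proteinmpnn_input(size=2, X=np.random.rand(10, 4, 3).astype(np.float32))
tensors = proteinmpnn_to_device('cpu', **batch)
assert isinstance(tensors['X'], torch.Tensor)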

setup_pose_batch_for_proteinmpnn

setup_pose_batch_for_proteinmpnn(batch_length: int, device, **parameters) -> dict[str, ndarray | Tensor]

Set up identical ProteinMPNN parameters as a batch and load them onto the device

Parameters:

  • batch_length (int) –

    The length of the batch to set up

  • device

    The device used for batch calculations

Returns:

  • dict[str, ndarray | Tensor]

    A mapping of necessary containers for ProteinMPNN inference in batches, loaded onto the device

Source code in symdesign/resources/ml.py
@torch.no_grad()  # Ensure no gradients are produced
def setup_pose_batch_for_proteinmpnn(batch_length: int, device, **parameters) -> dict[str, np.ndarray | torch.Tensor]:
    """

    Args:
        batch_length: The length the batch to set up
        device: The device used for batch calculations
    Returns:
        A mapping of necessary containers for ProteinMPNN inference in batches and loaded to the device
    """
    # batch_length = batch_slice.stop - batch_slice.start
    # Create batch_length fixed parameter data which are the same across poses
    batch_parameters: dict[str, np.ndarray | torch.Tensor] = \
        batch_proteinmpnn_input(size=batch_length, **parameters)
    # Move fixed data structures to the model device
    # Update parameters as some are not transferred to the identified device
    batch_parameters.update(proteinmpnn_to_device(device, **batch_parameters))

    return batch_parameters
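
This is equivalent to chaining the two previous functions; a sketch with a hypothetical array:

import numpy as np
from symdesign.resources.ml import setup_pose_batch_for_proteinmpnn

X = np.random.rand(10, 4, 3).astype(np.float32)
batch_parameters = setup_pose_batch_for_proteinmpnn(2, 'cpu', X=X)
# batch_parameters['X'] is a torch.Tensor with shape (2, 10, 4, 3)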

proteinmpnn_batch_design

proteinmpnn_batch_design(batch_slice: slice, proteinmpnn: ProteinMPNN, X: Tensor = None, randn: Tensor = None, S: Tensor = None, chain_mask: Tensor = None, chain_encoding: Tensor = None, residue_idx: Tensor = None, mask: Tensor = None, temperatures: Sequence[float] = (0.1,), pose_length: int = None, bias_by_res: Tensor = None, tied_pos: Iterable[Container] = None, X_unbound: Tensor = None, **batch_parameters) -> dict[str, ndarray]

Perform ProteinMPNN design tasks on input that is split into batches

Parameters:

  • batch_slice (slice) –

    The slice of the total work assigned to this batch

  • proteinmpnn (ProteinMPNN) –

    The ProteinMPNN model instance

  • X (Tensor, default: None ) –

    The backbone coordinates of the complexed state

  • randn (Tensor, default: None ) –

    The random noise used to set the decoding order

  • S (Tensor, default: None ) –

    The integer-encoded sequence

  • chain_mask (Tensor, default: None ) –

    A mask designating which chains are designable

  • chain_encoding (Tensor, default: None ) –

    The chain identity of each residue

  • residue_idx (Tensor, default: None ) –

    The index of each residue

  • mask (Tensor, default: None ) –

    A mask designating the valid residue positions

  • temperatures (Sequence[float], default: (0.1,) ) –

    The sampling temperature(s) to design with

  • pose_length (int, default: None ) –

    The number of residues in the pose

  • bias_by_res (Tensor, default: None ) –

    A per-residue amino acid bias

  • tied_pos (Iterable[Container], default: None ) –

    Groups of residue indices that should be designed together

  • X_unbound (Tensor, default: None ) –

    The backbone coordinates of the unbound state

Returns:

  • dict[str, ndarray]

    A mapping of each output key to the corresponding value, i.e. sequences, proteinmpnn_loss_complex, proteinmpnn_loss_unbound, and design_indices

Source code in symdesign/resources/ml.py
def proteinmpnn_batch_design(batch_slice: slice, proteinmpnn: ProteinMPNN,
                             X: torch.Tensor = None,
                             randn: torch.Tensor = None,
                             S: torch.Tensor = None,
                             chain_mask: torch.Tensor = None,
                             chain_encoding: torch.Tensor = None,
                             residue_idx: torch.Tensor = None,
                             mask: torch.Tensor = None,
                             temperatures: Sequence[float] = (0.1,),
                             pose_length: int = None,
                             bias_by_res: torch.Tensor = None,
                             tied_pos: Iterable[Container] = None,
                             X_unbound: torch.Tensor = None,
                             **batch_parameters
                             ) -> dict[str, np.ndarray]:
    """Perform ProteinMPNN design tasks on input that is split into batches

    Args:
        batch_slice:
        proteinmpnn:
        X:
        randn:
        S:
        chain_mask:
        chain_encoding:
        residue_idx:
        mask:
        temperatures:
        pose_length:
        bias_by_res:
        tied_pos:
        X_unbound:
    Returns:
        A mapping of each output key to the corresponding value, i.e. 'sequences', 'proteinmpnn_loss_complex',
            'proteinmpnn_loss_unbound', and 'design_indices'
    """
    # X = batch_parameters.pop('X', None)
    # S = batch_parameters.pop('S', None)
    # chain_mask = batch_parameters.pop('chain_mask', None)
    # chain_encoding = batch_parameters.pop('chain_encoding', None)
    # residue_idx = batch_parameters.pop('residue_idx', None)
    # mask = batch_parameters.pop('mask', None)
    # randn = batch_parameters.pop('randn', None)
    # # omit_AAs_np = batch_parameters.get('omit_AAs_np', None)
    # # bias_AAs_np = batch_parameters.get('bias_AAs_np', None)
    residue_mask = batch_parameters.pop('chain_M_pos', None)  # name change makes more sense
    # # omit_AA_mask = batch_parameters.get('omit_AA_mask', None)
    # # pssm_coef = batch_parameters.get('pssm_coef', None)
    # # pssm_bias = batch_parameters.get('pssm_bias', None)
    # # pssm_multi = batch_parameters.get('pssm_multi', None)
    # # pssm_log_odds_flag = batch_parameters.get('pssm_log_odds_flag', None)
    # # pssm_log_odds_mask = batch_parameters.get('pssm_log_odds_mask', None)
    # # pssm_bias_flag = batch_parameters.get('pssm_bias_flag', None)
    # tied_pos = batch_parameters.pop('tied_pos', None)
    # # tied_beta = batch_parameters.pop('tied_beta', None)
    # # bias_by_res = batch_parameters.get('bias_by_res', None)

    actual_batch_length = batch_slice.stop - batch_slice.start
    # # Clone the data from the sequence tensor so that it can be set with the null token below
    # S_design = S.detach().clone()
    if pose_length is None:
        batch_length, pose_length, *_ = S.shape
    else:
        batch_length, *_ = S.shape

    if actual_batch_length != batch_length:
        # Slice these for the last iteration
        X = X[:actual_batch_length]  # , None)
        chain_mask = chain_mask[:actual_batch_length]  # , None)
        chain_encoding = chain_encoding[:actual_batch_length]  # , None)
        residue_idx = residue_idx[:actual_batch_length]  # , None)
        mask = mask[:actual_batch_length]  # , None)
        bias_by_res = bias_by_res[:actual_batch_length]  # , None)
        randn = randn[:actual_batch_length]
        residue_mask = residue_mask[:actual_batch_length]
        S = S[:actual_batch_length]  # , None)
        # S_design = S_design[:actual_batch_length]  # , None)
        # Unpack, unpacked keyword args
        omit_AA_mask = batch_parameters.get('omit_AA_mask')
        pssm_coef = batch_parameters.get('pssm_coef')
        pssm_bias = batch_parameters.get('pssm_bias')
        pssm_log_odds_mask = batch_parameters.get('pssm_log_odds_mask')
        # Set keyword args
        batch_parameters['omit_AA_mask'] = omit_AA_mask[:actual_batch_length]
        batch_parameters['pssm_coef'] = pssm_coef[:actual_batch_length]
        batch_parameters['pssm_bias'] = pssm_bias[:actual_batch_length]
        batch_parameters['pssm_log_odds_mask'] = pssm_log_odds_mask[:actual_batch_length]
        try:
            X_unbound = X_unbound[:actual_batch_length]  # , None)
        except TypeError:  # Can't slice NoneType
            pass

    # # Use the sequence as an unknown token then guess the probabilities given the remaining
    # # information, i.e. the sequence and the backbone
    # S_design_null[residue_mask.type(torch.bool)] = MPNN_NULL_IDX
    chain_residue_mask = chain_mask * residue_mask

    batch_sequences = []
    _per_residue_complex_sequence_loss = []
    _per_residue_unbound_sequence_loss = []
    number_of_temps = len(temperatures)
    for temp_idx, temperature in enumerate(temperatures):
        sample_start_time = time.time()
        if tied_pos is None:
            sample_dict = proteinmpnn.sample(X, randn, S, chain_mask, chain_encoding, residue_idx, mask,
                                             chain_M_pos=residue_mask, temperature=temperature, bias_by_res=bias_by_res,
                                             **batch_parameters)
        else:
            sample_dict = proteinmpnn.tied_sample(X, randn, S, chain_mask, chain_encoding, residue_idx,
                                                  mask, chain_M_pos=residue_mask, temperature=temperature,
                                                  bias_by_res=bias_by_res, tied_pos=tied_pos, **batch_parameters)
        proteinmpnn.log.info(f'Sample calculation took {time.time() - sample_start_time:8f}s')

        # Format outputs - All have at least shape (batch_length, model_length,)
        S_sample = sample_dict['S']
        _batch_sequences = S_sample[:, :pose_length]
        # Check for null sequence output
        null_seq = _batch_sequences == 20
        # null_indices = np.argwhere(null_seq == 1)
        # if null_indices.nelement():  # Not an empty tensor...
        # Find the indices that are null on each axis
        null_design_indices, null_sequence_indices = torch.nonzero(null_seq == 1, as_tuple=True)
        if null_design_indices.nelement():  # Not an empty tensor...
            proteinmpnn.log.warning(f'Found null sequence output... Resampling selected positions')
            proteinmpnn.log.debug(f'At sequence position(s): {null_sequence_indices}')
            null_seq = (False,)
            sampled_probs = sample_dict['probs'].cpu()
            while not all(null_seq):  # null_indices.nelement():  # Not an empty tensor...
                # _decoding_order = decoding_order.cpu().numpy()[:, :pose_length] / 12  # Hard coded symmetry divisor
                # # (batch_length, pose_length)
                # print(f'Shape of null_seq: {null_seq.shape}')
                # print(f'Shape of _decoding_order: {_decoding_order.shape}')
                # print(f'Shape of _batch_sequences: {_batch_sequences.shape}')
                # print(f'Found the decoding sites with a null output: {_decoding_order[null_seq]}')
                # print(f'Found the probabilities with a null output: {_probabilities[null_seq]}')
                # print(_batch_sequences.numpy()[_decoding_order])
                # _probabilities = sample_dict['probs']  # .cpu().numpy()[:, :pose_length]
                # _probabilities with shape (batch_length, model_length, mpnn_alphabet_length)
                new_amino_acid_types = \
                    torch.multinomial(sampled_probs[null_design_indices, null_sequence_indices],
                                      1).squeeze(-1)
                # _batch_sequences[null_indices] = new_amino_acid_type
                # new_amino_acid_type = torch.multinomial(sample_dict['probs'][null_seq], 1)
                # _batch_sequences[null_seq] = new_amino_acid_type
                null_seq = new_amino_acid_types != 20
                # null_seq = _batch_sequences == 20
                # null_indices = np.argwhere(null_seq == 1)
            else:
                # Set the resampled amino acid types at the null positions
                _batch_sequences[null_design_indices, null_sequence_indices] = new_amino_acid_types
            # proteinmpnn.log.debug('Fixed null sequence elements')

        decoding_order = sample_dict['decoding_order']
        # decoding_order_out = decoding_order  # When using the same decoding order for all
        log_probs_start_time = time.time()
        if X_unbound is not None:
            # unbound_log_prob_start_time = time.time()
            # logger.critical(f'Starting unbound calc: '
            #                 f'available memory={get_device_memory(proteinmpnn.device)/gb_divisor}')
            unbound_log_probs = \
                proteinmpnn(X_unbound, S_sample, mask, chain_residue_mask, residue_idx, chain_encoding,
                            None,  # This argument is provided but with below args, is not used
                            use_input_decoding_order=True, decoding_order=decoding_order)
            # logger.critical(f'After unbound calc: '
            #                 f'available memory={get_device_memory(proteinmpnn.device)/gb_divisor}')
            _per_residue_unbound_sequence_loss.append(
                sequence_nllloss(_batch_sequences, unbound_log_probs[:, :pose_length]).cpu().numpy())
            # logger.debug(f'Unbound log probabilities calculation took '
            #              f'{time.time() - unbound_log_prob_start_time:8f}s')

        # logger.critical(f'Starting bound calc: '
        #                 f'available memory={get_device_memory(proteinmpnn.device) / gb_divisor}')
        complex_log_probs = \
            proteinmpnn(X, S_sample, mask, chain_residue_mask, residue_idx, chain_encoding,
                        None,  # This argument is provided but with below args, is not used
                        use_input_decoding_order=True, decoding_order=decoding_order)
        # logger.critical(f'After bound calc: '
        #                 f'available memory={get_device_memory(proteinmpnn.device) / gb_divisor}')
        # complex_log_probs is
        # tensor([[[-2.7691, -3.5265, -2.9001,  ..., -3.3623, -3.0247, -4.2772],
        #          [-2.7691, -3.5265, -2.9001,  ..., -3.3623, -3.0247, -4.2772],
        #          [-2.7691, -3.5265, -2.9001,  ..., -3.3623, -3.0247, -4.2772],
        #          ...,
        #          [-2.7691, -3.5265, -2.9001,  ..., -3.3623, -3.0247, -4.2772],
        #          [-2.7691, -3.5265, -2.9001,  ..., -3.3623, -3.0247, -4.2772],
        #          [-2.7691, -3.5265, -2.9001,  ..., -3.3623, -3.0247, -4.2772]],
        #         [[-2.6934, -4.0610, -2.6506, ..., -4.2404, -3.4620, -4.8641],
        #          [-2.8753, -4.3959, -2.4042,  ..., -4.4922, -3.5962, -5.1403],
        #          [-2.5235, -4.0181, -2.7738,  ..., -4.2454, -3.4768, -4.8088],
        #          ...,
        #          [-3.4500, -4.4373, -3.7814,  ..., -5.1637, -4.6107, -5.2295],
        #          [-0.9690, -4.9492, -3.9373,  ..., -2.0154, -2.2262, -4.3334],
        #          [-3.1118, -4.3809, -3.8763,  ..., -4.7145, -4.1524, -5.3076]]])
        # Score the redesigned structure-sequence
        # mask_for_loss = chain_mask_and_mask*residue_mask
        # batch_scores = sequence_nllloss(S_sample, complex_log_probs, mask_for_loss, per_residue=False)
        # batch_scores is
        # tensor([2.1039, 2.0618, 2.0802, 2.0538, 2.0114, 2.0002], device='cuda:0')
        _per_residue_complex_sequence_loss.append(
            sequence_nllloss(_batch_sequences, complex_log_probs[:, :pose_length]).cpu().numpy())
        proteinmpnn.log.info(f'Log probabilities score calculation took {time.time() - log_probs_start_time:8f}s')
        batch_sequences.append(_batch_sequences.cpu())

    # Reshape data structures to have shape (batch_length, number_of_temperatures, pose_length)
    _residue_indices_of_interest = residue_mask[:, :pose_length].cpu().numpy().astype(bool)
    sequences = np.concatenate(batch_sequences, axis=1).reshape(actual_batch_length, number_of_temps, pose_length)
    complex_sequence_loss =\
        np.concatenate(_per_residue_complex_sequence_loss, axis=1) \
        .reshape(actual_batch_length, number_of_temps, pose_length)
    if X_unbound is not None:
        unbound_sequence_loss = \
            np.concatenate(_per_residue_unbound_sequence_loss, axis=1) \
            .reshape(actual_batch_length, number_of_temps, pose_length)
    else:
        unbound_sequence_loss = np.empty_like(complex_sequence_loss)

    return {'sequences': sequences,
            'proteinmpnn_loss_complex': complex_sequence_loss,
            'proteinmpnn_loss_unbound': unbound_sequence_loss,
            'design_indices': _residue_indices_of_interest}
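
A hypothetical sketch of how this function could be paired with batch_calculation and setup_pose_batch_for_proteinmpnn; number_of_sequences, batch_length, proteinmpnn, parameters, and containers are all placeholders, and the actual pipeline wiring lives elsewhere in symdesign:

# Hypothetical placeholders: number_of_sequences, batch_length, proteinmpnn,
# parameters (pose arrays), and containers (preallocated numpy arrays)
design_batched = batch_calculation(
    size=number_of_sequences, batch_length=batch_length,
    setup=setup_pose_batch_for_proteinmpnn,  # Called as setup(batch_length, *setup_args, **setup_kwargs)
    compute_failure_exceptions=(RuntimeError,))(proteinmpnn_batch_design)
results = design_batched(proteinmpnn, setup_args=(proteinmpnn.device,),
                         setup_kwargs=parameters, return_containers=containers)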

proteinmpnn_batch_score

proteinmpnn_batch_score(batch_slice: slice, proteinmpnn: ProteinMPNN, X: Tensor = None, S: Tensor = None, chain_mask: Tensor = None, chain_encoding: Tensor = None, residue_idx: Tensor = None, mask: Tensor = None, pose_length: int = None, X_unbound: Tensor = None, chain_M_pos: Tensor = None, residue_mask: Tensor = None, randn: Tensor = None, decoding_order: Tensor = None, **batch_parameters) -> dict[str, ndarray]

Perform ProteinMPNN scoring tasks on input that is split into batches

Parameters:

  • batch_slice (slice) –

    The slice of the total work assigned to this batch

  • proteinmpnn (ProteinMPNN) –

    The ProteinMPNN model instance

  • X (Tensor, default: None ) –

    The backbone coordinates of the complexed state

  • S (Tensor, default: None ) –

    The integer-encoded sequence

  • chain_mask (Tensor, default: None ) –

    A mask designating which chains are designable

  • chain_encoding (Tensor, default: None ) –

    The chain identity of each residue

  • residue_idx (Tensor, default: None ) –

    The index of each residue

  • mask (Tensor, default: None ) –

    A mask designating the valid residue positions

  • pose_length (int, default: None ) –

    The number of residues in the pose

  • X_unbound (Tensor, default: None ) –

    The backbone coordinates of the unbound state

  • chain_M_pos (Tensor, default: None ) –

    A mask designating the designable residue positions (alias of residue_mask)

  • residue_mask (Tensor, default: None ) –

    A mask designating the designable residue positions

  • randn (Tensor, default: None ) –

    The random noise used to set the decoding order

  • decoding_order (Tensor, default: None ) –

    A precomputed decoding order

Returns: A mapping of each score type to its value, i.e. proteinmpnn_loss_complex and proteinmpnn_loss_unbound

Source code in symdesign/resources/ml.py
def proteinmpnn_batch_score(batch_slice: slice, proteinmpnn: ProteinMPNN,
                            X: torch.Tensor = None,
                            S: torch.Tensor = None,
                            chain_mask: torch.Tensor = None,
                            chain_encoding: torch.Tensor = None,
                            residue_idx: torch.Tensor = None,
                            mask: torch.Tensor = None,
                            pose_length: int = None,
                            X_unbound: torch.Tensor = None,
                            chain_M_pos: torch.Tensor = None,
                            residue_mask: torch.Tensor = None,
                            randn: torch.Tensor = None,
                            decoding_order: torch.Tensor = None,
                            **batch_parameters
                            ) -> dict[str, np.ndarray]:
    """Perform ProteinMPNN design tasks on input that is split into batches

    Args:
        batch_slice: The slice of the total number of sequences to score in this batch
        proteinmpnn: The ProteinMPNN model instance used for scoring
        X: The backbone coordinates of the pose
        S: The integer-encoded sequences to score
        chain_mask: A mask marking the chains that are being designed
        chain_encoding: The integer index of the chain to which each residue belongs
        residue_idx: The residue number of each position, with index jumps between chains
        mask: A mask marking the positions with valid backbone coordinates
        pose_length: The number of residues in the pose
        X_unbound: The backbone coordinates of the pose in the unbound state
        chain_M_pos: An alias for residue_mask using the ProteinMPNN variable name
        residue_mask: A mask marking the designable residue positions
        randn: Random noise used to generate the decoding order
        decoding_order: An explicit decoding order to use instead of one generated from randn
    Returns:
        A mapping of each score type to its value, i.e. proteinmpnn_loss_complex and
            proteinmpnn_loss_unbound
    """
    if chain_M_pos is not None:
        residue_mask = chain_M_pos  # Name change makes more sense
    elif residue_mask is not None:
        pass
    else:
        raise ValueError("Must pass either 'residue_mask' or 'chain_M_pos'")

    if pose_length is None:
        batch_length, pose_length, *_ = X.shape
    else:
        batch_length, *_ = X.shape

    actual_batch_length = batch_slice.stop - batch_slice.start

    # Slice the sequence according to those that are currently batched for scoring
    S = S[batch_slice]  # , None)
    if actual_batch_length != batch_length:
        # Slice these for the last iteration
        X = X[:actual_batch_length]  # , None)
        chain_mask = chain_mask[:actual_batch_length]  # , None)
        chain_encoding = chain_encoding[:actual_batch_length]  # , None)
        residue_idx = residue_idx[:actual_batch_length]  # , None)
        mask = mask[:actual_batch_length]  # , None)
        # randn = randn[:actual_batch_length]
        residue_mask = residue_mask[:actual_batch_length]
        try:
            X_unbound = X_unbound[:actual_batch_length]  # , None)
        except TypeError:  # Can't slice NoneType
            pass

    # logger.debug(f'S shape: {S.shape}')
    # logger.debug(f'X shape: {X.shape}')
    # # logger.debug(f'chain_mask shape: {chain_mask.shape}')
    # logger.debug(f'chain_encoding shape: {chain_encoding.shape}')
    # logger.debug(f'residue_idx shape: {residue_idx.shape}')
    # logger.debug(f'mask shape: {mask.shape}')
    # # logger.debug(f'residue_mask shape: {residue_mask.shape}')

    chain_residue_mask = chain_mask * residue_mask
    # logger.debug(f'chain_residue_mask shape: {chain_residue_mask.shape}')

    # Score and format outputs - All have at least shape (batch_length, model_length,)
    if decoding_order is not None:
        # logger.debug(f'decoding_order shape: {decoding_order.shape}, type: {decoding_order.dtype}')
        decoding_order = decoding_order[:actual_batch_length]
        provided_decoding_order = True
        randn = None
    elif randn is not None:
        # logger.debug(f'decoding_order shape: {randn.shape}, type: {randn.dtype}')
        randn = randn[:actual_batch_length]
        decoding_order = None
        provided_decoding_order = False
    else:
        # Todo generate a randn fresh?
        raise ValueError("Missing required argument 'randn' or 'decoding_order'")

    # decoding_order_out = decoding_order  # When using the same decoding order for all
    log_probs_start_time = time.time()

    # Todo debug the input Tensor. Most likely the sequence must be (batch, pose, aa?)
    # RuntimeError: Index tensor must have the same number of dimensions as input tensor
    complex_log_probs = \
        proteinmpnn(X, S, mask, chain_residue_mask, residue_idx, chain_encoding, randn,
                    use_input_decoding_order=provided_decoding_order, decoding_order=decoding_order)
    per_residue_complex_sequence_loss = \
        sequence_nllloss(S[:, :pose_length], complex_log_probs[:, :pose_length]).cpu().numpy()

    # Reshape data structures to have shape (batch_length, number_of_temperatures, pose_length)
    # _residue_indices_of_interest = residue_mask[:, :pose_length].cpu().numpy().astype(bool)
    # sequences = np.concatenate(batch_sequences, axis=1).reshape(actual_batch_length, number_of_temps, pose_length)
    # complex_sequence_loss = \
    #     np.concatenate(per_residue_complex_sequence_loss, axis=1)\
    #     .reshape(actual_batch_length, number_of_temps, pose_length)
    # if X_unbound is not None:
    #     unbound_sequence_loss = \
    #         np.concatenate(per_residue_unbound_sequence_loss, axis=1)\
    #         .reshape(actual_batch_length, number_of_temps, pose_length)
    # else:
    #     unbound_sequence_loss = np.empty_like(complex_sequence_loss)
    if X_unbound is not None:
        # unbound_log_prob_start_time = time.time()
        unbound_log_probs = \
            proteinmpnn(X_unbound, S, mask, chain_residue_mask, residue_idx, chain_encoding, randn,
                        use_input_decoding_order=provided_decoding_order, decoding_order=decoding_order)
        per_residue_unbound_sequence_loss = \
            sequence_nllloss(S[:, :pose_length], unbound_log_probs[:, :pose_length]).cpu().numpy()
        # logger.debug(f'Unbound log probabilities calculation took '
        #              f'{time.time() - unbound_log_prob_start_time:8f}s')
    else:
        per_residue_unbound_sequence_loss = np.empty_like(per_residue_complex_sequence_loss)
    proteinmpnn.log.info(f'Log probabilities score calculation took {time.time() - log_probs_start_time:8f}s')

    return {'proteinmpnn_loss_complex': per_residue_complex_sequence_loss,
            'proteinmpnn_loss_unbound': per_residue_unbound_sequence_loss}
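
A minimal driver sketch for this routine, assuming the scoring tensors have already been prepared elsewhere; the helper name score_in_batches and the batch sizing are illustrative, not part of the module.

import math

def score_in_batches(proteinmpnn, number_of_sequences: int, batch_length: int,
                     **score_kwargs) -> list[dict]:
    """Call proteinmpnn_batch_score over consecutive slices of the input"""
    results = []
    for batch_idx in range(math.ceil(number_of_sequences / batch_length)):
        batch_slice = slice(batch_idx * batch_length,
                            min((batch_idx + 1) * batch_length, number_of_sequences))
        results.append(
            proteinmpnn_batch_score(batch_slice, proteinmpnn, **score_kwargs))
    return results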

sequence_nllloss

sequence_nllloss(sequence: Tensor, log_probs: Tensor, mask: Tensor = None, per_residue: bool = True) -> Tensor

Score designed sequences using the Negative log likelihood loss function

Parameters:

  • sequence (Tensor) –

    The sequence tensor

  • log_probs (Tensor) –

    The logarithmic probabilities at each residue for every amino acid. This may be found by an evolutionary profile or a forward pass through ProteinMPNN

  • mask (Tensor, default: None ) –

    Any positions that are masked in the design task

  • per_residue (bool, default: True ) –

    Whether to return scores per residue

Returns: The loss calculated over the log probabilities compared to the sequence tensor. If per_residue=True, the returned Tensor has the same shape as sequence (i.e. (batch, length)); otherwise, the loss is reduced over the length dimension to a single score per sequence, either a plain sum or, when mask is provided, the mean over the masked positions

Source code in symdesign/resources/ml.py
def sequence_nllloss(sequence: torch.Tensor, log_probs: torch.Tensor,
                     mask: torch.Tensor = None, per_residue: bool = True) -> torch.Tensor:
    """Score designed sequences using the Negative log likelihood loss function

    Args:
        sequence: The sequence tensor
        log_probs: The logarithmic probabilities at each residue for every amino acid.
            This may be found by an evolutionary profile or a forward pass through ProteinMPNN
        mask: Any positions that are masked in the design task
        per_residue: Whether to return scores per residue
    Returns:
        The loss calculated over the log probabilities compared to the sequence tensor.
            If per_residue=True, the returned Tensor has the same shape as sequence (i.e. (batch, length)),
            otherwise, the loss is reduced over the length dimension to a single score per sequence,
            either a plain sum or, when mask is provided, the mean over the masked positions
    """
    criterion = torch.nn.NLLLoss(reduction='none')
    # Measure log_probs loss with respect to the sequence. Make each sequence and log probs stacked along axis=0
    loss = criterion(
        log_probs.contiguous().view(-1, log_probs.size(-1)),
        sequence.contiguous().view(-1)
    ).view(sequence.size())  # Revert the shape to the original sequence shape
    # Otherwise, reduce over the length dimension to a single score per sequence
    if per_residue:
        return loss
    elif mask is None:
        return torch.sum(loss, dim=-1)
    else:
        return torch.sum(loss*mask, dim=-1) / torch.sum(mask, dim=-1)
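
A worked toy example, with random logits standing in for a real ProteinMPNN forward pass; the alphabet size of 4 is arbitrary.

import torch

# Two sequences of length 3 over a toy 4-letter alphabet
log_probs = torch.log_softmax(torch.randn(2, 3, 4), dim=-1)
sequence = torch.randint(0, 4, (2, 3))
per_residue = sequence_nllloss(sequence, log_probs)  # Shape (2, 3)
# Mask the last position of the first sequence out of the reduction
mask = torch.tensor([[1., 1., 0.], [1., 1., 1.]])
per_design = sequence_nllloss(sequence, log_probs, mask=mask, per_residue=False)
print(per_residue.shape, per_design.shape)  # torch.Size([2, 3]) torch.Size([2])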

jnp_to_np

jnp_to_np(jax_dict: dict[str, Any]) -> dict[str, Any]

Recursively changes jax arrays to numpy arrays

Parameters:

  • jax_dict (dict[str, Any]) –

    A dictionary whose values are jax.numpy.ndarray types, possibly nested

Returns: The input dictionary with every jax array converted to np.ndarray

Source code in symdesign/resources/ml.py
def jnp_to_np(jax_dict: dict[str, Any]) -> dict[str, Any]:
    """Recursively changes jax arrays to numpy arrays

    Args:
        jax_dict: A dictionary whose values are jax.numpy.ndarray types, possibly nested
    Returns:
        The input dictionary with every jax array converted to np.ndarray
    """
    for k, v in jax_dict.items():
        if isinstance(v, dict):
            jax_dict[k] = jnp_to_np(v)
        elif isinstance(v, jnp.ndarray):
            jax_dict[k] = np.array(v)
    return jax_dict
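
A small usage sketch; note the conversion happens in place as well as through the return value.

import jax.numpy as jnp
import numpy as np

result = {'plddt': jnp.ones(10), 'nested': {'ptm': jnp.asarray(0.87)}}
np_result = jnp_to_np(result)
assert isinstance(np_result['plddt'], np.ndarray)
assert isinstance(np_result['nested']['ptm'], np.ndarray)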

calculate_alphafold_batch_length

calculate_alphafold_batch_length(device: Device, number_of_residues: int, element_memory: int = 4) -> int

Calculate the batch size that fits on a device for a model with the given number of residues

Parameters:

  • device (Device) –

    The jax device that will run the computation

  • number_of_residues (int) –

    The number of residues in the model input

  • element_memory (int, default: 4 ) –

    The number of bytes per tensor element, e.g. 4 for np.float32 or np.int32

Returns: The size of the batch that can be completed on the given device

Source code in symdesign/resources/ml.py
def calculate_alphafold_batch_length(device: jax_xla.Device, number_of_residues: int, element_memory: int = 4) -> int:
    """

    Args:
        device: The ProteinMPNN model
        number_of_residues: The number of residues used in the ProteinMPNN model
        element_memory: Where each element is np.int64, np.float32, etc.
    Returns:
        The size of the batch that can be completed for the ProteinMPNN model given it's device
    """
    memory_constraint = get_device_memory(device)

    number_of_elements_available = memory_constraint // element_memory
    logger.debug(f'The number_of_elements_available is: {number_of_elements_available}')
    # NOTE: 'model' and 'proteinmpnn_batch_divisor' are assumed to come from the enclosing module scope
    number_of_model_parameter_elements = sum([math.prod(param.size()) for param in model.parameters()])
    logger.debug(f'The number_of_model_parameter_elements is: {number_of_model_parameter_elements}')
    model_elements = number_of_model_parameter_elements
    # Todo use 5 as ideal CB is added by the model later with ca_only = False
    num_model_residues = 5
    model_elements += math.prod((number_of_residues, num_model_residues, 3))  # X,
    model_elements += number_of_residues  # S.shape
    model_elements += number_of_residues  # chain_mask.shape
    model_elements += number_of_residues  # chain_encoding.shape
    model_elements += number_of_residues  # residue_idx.shape
    model_elements += number_of_residues  # mask.shape
    model_elements += number_of_residues  # residue_mask.shape
    model_elements += math.prod((number_of_residues, 21))  # omit_AA_mask.shape
    model_elements += number_of_residues  # pssm_coef.shape
    model_elements += math.prod((number_of_residues, 20))  # pssm_bias.shape
    model_elements += math.prod((number_of_residues, 20))  # pssm_log_odds_mask.shape
    model_elements += number_of_residues  # tied_beta.shape
    model_elements += math.prod((number_of_residues, 21))  # bias_by_res.shape
    logger.debug(f'The number of model_elements is: {model_elements}')

    number_of_batches = number_of_elements_available // model_elements
    batch_length = number_of_batches // proteinmpnn_batch_divisor
    if batch_length == 0:
        not_enough_proteinmpnn_memory = f"Can't find a device for {model} with enough memory to complete a single " \
                                        f"batch of work with {number_of_residues} residues in the model"
        if device.platform == 'cpu':
            raise RuntimeError(not_enough_proteinmpnn_memory)

        old_device = device
        # This won't work. Try to put the model on a new device
        max_memory = vanilla_model_memory
        for device_int in range(torch.cuda.device_count()):
            available_memory = get_device_memory(torch.device(device_int), free=True)
            if available_memory > max_memory:
                max_memory = available_memory
                device_id = device_int
        try:
            device: torch.device = torch.device(device_id)
        except UnboundLocalError:  # No device has memory greater than ProteinMPNN minimum required
            device = jax.devices('cpu')[0]

        if device == old_device:
            # Solve using gpu is stuck
            if device.type == 'cpu':
                # This hasn't been changed or device is cpu
                raise RuntimeError(not_enough_proteinmpnn_memory)
            else:
                # Try one more time ensuring cpu. This will be caught above if still not enough memory
                device = torch.device('cpu')

        # Recurse
        return calculate_alphafold_batch_length(device, number_of_residues, element_memory)

    return batch_length
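
A back-of-the-envelope version of the element accounting above, with hypothetical numbers: an 8 GiB device, a 500-residue pose, and 4-byte elements. It ignores the model parameter elements and the batch divisor used in the full routine.

# Per-residue element counts mirror the model_elements bookkeeping above:
# X (5 * 3), six per-residue masks/indices, omit_AA_mask (21), pssm_coef (1),
# pssm_bias (20), pssm_log_odds_mask (20), tied_beta (1), bias_by_res (21)
free_bytes = 8 * 1024**3
element_memory = 4
number_of_residues = 500

elements_available = free_bytes // element_memory
elements_per_member = number_of_residues * (5 * 3 + 6 + 21 + 1 + 20 + 20 + 1 + 21)
print(elements_available // elements_per_member)  # Upper bound on the batch size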

get_jax_device_memory

get_jax_device_memory(device_int: int) -> int

Based on the device number, use torch to get the device memory

Source code in symdesign/resources/ml.py
def get_jax_device_memory(device_int: int) -> int:  # jax_xla.Device
    """Based on the device number, use torch to get the device memory"""
    return get_device_memory(device_int)

alphafold_required_memory

alphafold_required_memory(number_of_residues: int)

Get the bytes required for the number of residues in the model

Parameters:

  • number_of_residues (int) –

    The number of residues in the model

Source code in symdesign/resources/ml.py
def alphafold_required_memory(number_of_residues: int):
    """Get the bytes required for the number of residues in the model

    Args:
        number_of_residues: The number of residues in the model
    """
    # Check from smallest to largest so that each threshold maps to its memory amount
    if number_of_residues <= 100:
        return 629145600  # '600 M'
    elif number_of_residues <= 200:
        return 1258291200  # '1200 M'
    elif number_of_residues <= 400:
        return 2516582400  # '2400 M'
    elif number_of_residues <= 800:
        return 5033164800  # '4800 M'
    elif number_of_residues <= 1000:
        return 6291456000  # '6000 M'
    else:  # Assume 2000+
        return 12582912000  # '12000 M'
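
The same thresholds can also be written as a table-driven lookup; a sketch of an equivalent formulation follows (the name required_memory_bytes is illustrative, not part of the module).

_RESIDUE_MEMORY_THRESHOLDS = [
    (100, 629145600),    # '600 M'
    (200, 1258291200),   # '1200 M'
    (400, 2516582400),   # '2400 M'
    (800, 5033164800),   # '4800 M'
    (1000, 6291456000),  # '6000 M'
]

def required_memory_bytes(number_of_residues: int) -> int:
    for max_residues, memory in _RESIDUE_MEMORY_THRESHOLDS:
        if number_of_residues <= max_residues:
            return memory
    return 12582912000  # '12000 M' for anything larger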

get_alphafold_model_device

get_alphafold_model_device(number_of_residues: int) -> Device

Get the GPU capable of performing the AlphaFold inference for the number of residues in the model

Parameters:

  • number_of_residues (int) –

    The number of residues in the model

Returns: The jax.Device to use with the number of residues present in the model

Source code in symdesign/resources/ml.py
1335
1336
1337
1338
1339
1340
1341
1342
1343
1344
1345
1346
1347
1348
1349
1350
1351
1352
1353
1354
1355
def get_alphafold_model_device(number_of_residues: int) -> jax_xla.Device:
    """Get the GPU capable of performing the AlphaFold inference for the number of residues in the model

    Args:
        number_of_residues: The number of residues in the model
    Returns:
        The jax.Device to use with the number of residues present in the model
    """
    use_device = None
    max_memory = alphafold_required_memory(number_of_residues)
    for device_int, device in enumerate(jax.devices('gpu')):
        available_memory = get_jax_device_memory(device_int)
        if available_memory > max_memory:
            max_memory = available_memory
            use_device = device

    if use_device is None:
        raise RuntimeError(
            f"Couldn't find a usable device with the memory requirement {max_memory}")
    else:
        return use_device
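
A short usage sketch; falling back when no GPU qualifies is the caller's responsibility since the function raises RuntimeError.

try:
    af_device = get_alphafold_model_device(number_of_residues=350)
except RuntimeError:
    af_device = None  # e.g. defer the prediction or fall back to CPU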

set_up_model_runners

set_up_model_runners(model_type: af_model_literal = 'monomer', number_of_residues: int = 1000, num_predictions_per_model: int = 1, num_ensemble: int = 1, development: bool = False) -> dict[str, RunModel]

Produce Alphafold RunModel class loaded with their training parameters

Parameters:

  • model_type (af_model_literal, default: 'monomer' ) –

    The type of model to load. Should be one of the viable Alphafold models including: 'monomer', 'monomer_casp14', 'monomer_ptm', 'multimer'

  • number_of_residues (int, default: 1000 ) –

    The number of residues in the model. Used only to calculate approximate memory needs during device allocation

  • num_predictions_per_model (int, default: 1 ) –

    The number of predictions to make for each Alphafold model. Essentially duplicates the original models 'num_predictions_per_model' times

  • num_ensemble (int, default: 1 ) –

    The number of model ensembles to make. Typically, 1 is sufficient, but during CASP14, 8 were used

  • development (bool, default: False ) –

    Whether a smaller subset of models should be used for increased testing performance

Returns: A dictionary of the model name to the RunModel instance for each 'model_type'/'num_predictions_per_model' requested

Source code in symdesign/resources/ml.py
1366
1367
1368
1369
1370
1371
1372
1373
1374
1375
1376
1377
1378
1379
1380
1381
1382
1383
1384
1385
1386
1387
1388
1389
1390
1391
1392
1393
1394
1395
1396
1397
1398
1399
1400
1401
1402
1403
1404
1405
1406
1407
1408
1409
1410
1411
1412
1413
1414
1415
1416
1417
1418
1419
1420
1421
1422
1423
1424
1425
1426
1427
1428
1429
1430
1431
1432
1433
1434
1435
1436
1437
1438
1439
def set_up_model_runners(model_type: af_model_literal = 'monomer', number_of_residues: int = 1000, num_predictions_per_model: int = 1,
                         num_ensemble: int = 1, development: bool = False) -> dict[str, RunModel]:
    """Produce Alphafold RunModel class loaded with their training parameters

    Args:
        model_type: The type of model to load. Should be one of the viable Alphafold models including:
            'monomer', 'monomer_casp14', 'monomer_ptm', 'multimer'
        number_of_residues: The number of residues in the model. Used only to calculate approximate memory needs during
            device allocation
        num_predictions_per_model: The number of predictions to make for each Alphafold model. Essentially duplicates
            the original models 'num_predictions_per_model' times
        num_ensemble: The number of model ensembles to make. Typically, 1 is sufficient, but during CASP14, 8 were used
        development: Whether a smaller subset of models should be used for increased testing performance
    Returns:
        A dictionary of the model name to the RunModel instance for each 'model_type'/'num_predictions_per_model'
            requested
    """
    # model_runners = {}
    # model_names = afconfig.MODEL_PRESETS[model_type]
    # for model_name in model_names:
    #     if development and model_name != 'model_2_multimer_v3':
    #         continue
    #     model_config = afconfig.model_config(model_name)
    #     if model_config.model.global_config.multimer_mode:
    #         model_config.model.num_ensemble_eval = num_ensemble
    #     else:
    #         model_config.data.eval.num_ensemble = num_ensemble
    #     model_params = afdata.get_model_haiku_params(model_name=model_name, data_dir=putils.alphafold_db_dir)
    #     # This is using prev_pos init
    #     model_runner = RunModel(model_config, model_params)
    #     # This should be used if the prediction is not for a design and we have an msa
    #     # model_runner = afmodel.RunModel(model_config, model_params)
    #
    #     for i in range(num_predictions_per_model):
    #         model_runners[f'{model_name}_pred_{i}'] = model_runner
    #
    # num_models = len(model_runners)
    # logger.info(f'Loaded {num_models} Alphafold models: {list(model_runners.keys())}')
    #
    # return model_runners

    # This routine stores each separate model's parameters on one RunModel
    # Get model config
    model_config = afconfig.model_config(model_type_to_config_name[model_type])
    if model_config.model.global_config.multimer_mode:
        model_config.model.num_ensemble_eval = num_ensemble
    else:
        model_config.data.eval.num_ensemble = num_ensemble
    # Set up model params
    model_params = {}
    for model_name in afconfig.MODEL_PRESETS[model_type]:
        model_param = afdata.get_model_haiku_params(model_name=model_name, data_dir=putils.alphafold_db_dir)
        if 'model_1' in model_name:
            # Using the config for model_1 as it is most similar to other models
            #  model_1/2 includes template embeddings (monomer),
            #  while multimer model_1 is fairly similar to 2-5
            af_device = get_alphafold_model_device(number_of_residues)
            logger.info(f'Using device {af_device} for model calculations')
            # RunModel is using prev_pos init
            model_runner = RunModel(model_config, model_param, device=af_device)
            # # ?? Not sure why this would be the case -> if the prediction is not for a design and there is an msa
            # model_runner = afmodel.RunModel(model_config, model_params)
            if development:
                model_params[f'{model_name}_pred_{0}'] = model_param
                break

        for i in range(num_predictions_per_model):
            model_params[f'{model_name}_pred_{i}'] = model_param

    num_models = len(model_params)
    logger.info(f'Loaded {num_models} Alphafold models: {", ".join(model_params)}')

    model_runner.set_params(model_params)
    return {model_param_name: model_runner for model_param_name in model_params}
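
A hedged usage sketch: build multimer runners sized for a ~300-residue complex with two predictions per parameter set. The exact model names in the returned mapping depend on afconfig.MODEL_PRESETS.

model_runners = set_up_model_runners(
    model_type='multimer',
    number_of_residues=300,
    num_predictions_per_model=2,
)
print(list(model_runners))  # e.g. ['model_1_multimer_v3_pred_0', ...]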

af_predict

af_predict(features: FeatureDict, model_runners: dict[str, RunModel], gpu_relax: bool = False, models_to_relax: relax_options_literal = None, random_seed: int = None, confidence_stop_threshold: float = 0.85) -> tuple[dict[str, dict[str, str]], dict[str, FeatureDict]]

Run Alphafold to predict a structure from sequence/msa/template features

Parameters:

  • features (FeatureDict) –

    The sequence/msa/template feature parameters to populate the jax model

  • model_runners (dict[str, RunModel]) –

    The RunModel instances which should predict the structure

  • gpu_relax (bool, default: False ) –

    Whether predictions should be relaxed using a GPU (if one is available)

  • models_to_relax (relax_options_literal, default: None ) –

    Specify which predictions should be relaxed

  • random_seed (int, default: None ) –

    A random integer to seed the model. Could be provided to ensure consistency across runs

  • confidence_stop_threshold (float, default: 0.85 ) –

    The confidence threshold to stop prediction if a prediction scores higher than it. Value provided should be between [0-1]. Will use mean plddt if the model is monomer; if the model is multimer, will use 0.8*interface_predicted_template_modeling_score + 0.2*predicted_template_modeling_score

Returns: A tuple of structure and score dictionaries. structures contains the keys 'relaxed' and 'unrelaxed', each mapping a model name to its PDB string, while folding_scores maps each model name to the score types 'predicted_aligned_error' (length, length), 'plddt' (length), 'predicted_template_modeling_score' (1), and 'predicted_interface_template_modeling_score' (1)

Source code in symdesign/resources/ml.py
def af_predict(features: FeatureDict, model_runners: dict[str, RunModel], gpu_relax: bool = False,
               models_to_relax: relax_options_literal = None, random_seed: int = None,
               confidence_stop_threshold: float = 0.85) -> tuple[dict[str, dict[str, str]], dict[str, FeatureDict]]:
    """Run Alphafold to predict a structure from sequence/msa/template features

    Args:
        # length: The length of the desired output for prediction metrics
        features: The sequence/msa/template feature parameters to populate the jax model
        model_runners: The RunModel instances which should predict the structure
        gpu_relax: Whether predictions should be relaxed using a GPU (if one is available)
        models_to_relax: Specify which predictions should be relaxed
        random_seed: A random integer to seed the model. Could be provided to ensure consistency across runs
        confidence_stop_threshold: The confidence threshold to stop prediction if a prediction scores higher than it.
            Value provided should be between [0-1]. Will use mean plddt if the model is monomer, if model is multimer,
            will use 0.8*interface_predicted_template_modeling_score + 0.2*predicted_template_modeling_score
    Returns:
        A tuple of structure and score dictionaries. structures contains the keys 'relaxed' and
        'unrelaxed', each mapping a model name to its PDB string, while folding_scores maps each
        model name to the score types 'predicted_aligned_error' (length, length), 'plddt' (length),
        'predicted_template_modeling_score' (1), and 'predicted_interface_template_modeling_score' (1)
    """
    num_models = len(model_runners)
    if random_seed is None:  # Make one
        random_seed = random.randrange(sys.maxsize // num_models)

    if confidence_stop_threshold > 1:
        raise ValueError(f"confidence_stop_threshold must be between 0 and 1. If using monomer models with plddt, "
                         f"will take the confidence metric as the percent plddt (out of a maximum of 100)")

    # # Set up folding_scores dictionary
    # scores = {
    #     'predicted_aligned_error': np.zeros((num_models, length, length), dtype=np.float32),
    #     'plddt': np.zeros((num_models, length), dtype=np.float32),
    #     'predicted_template_modeling_score': np.zeros(num_models, dtype=np.float32),
    #     'predicted_interface_template_modeling_score': np.zeros(num_models, dtype=np.float32)
    # }
    for model_name, model_runner in model_runners.items():
        if model_runner.multimer_mode:
            change_scores = [('iptm', 'predicted_interface_template_modeling_score'),
                             ('ptm', 'predicted_template_modeling_score')]
            # scores_ = {'predicted_template_modeling_score': [],
            #            'predicted_interface_template_modeling_score': []}
        elif 'ptm' in model_name:
            change_scores = [('ptm', 'predicted_template_modeling_score')]
            # scores_ = {'predicted_template_modeling_score': []}
        else:
            change_scores = []
            # raise NotImplementedError()
        break
    else:  # Can't run without model_runners...
        change_scores = []
    #     scores_ = {}
    # scores = {model_name: copy.deepcopy(scores_) for model_name in model_runners}

    unneeded_scores = [
        'distogram', 'experimentally_resolved', 'masked_msa', 'predicted_lddt', 'structure_module',
        'final_atom_positions', 'num_recycles', 'aligned_confidence_probs',
        'max_predicted_aligned_error',
        # 'ranking_confidence',
        # 'ptm', 'iptm', 'predicted_aligned_error', 'plddt',
    ]
    scores = {}
    ranking_confidences = {}
    unrelaxed_proteins = {}
    unrelaxed_pdbs_ = {}
    # Run the models.
    for model_index, (model_name, model_runner) in enumerate(model_runners.items()):
        logger.info(f'Running JAX {model_name}')
        model_random_seed = model_index + random_seed*num_models
        processed_feature_dict = \
            model_runner.process_features(features, random_seed=model_random_seed)

        t_0 = time.time()
        # prediction_result = model_runner.predict(processed_feature_dict, random_seed=model_random_seed)
        prediction_result = \
            model_runner.predict_with_params(model_name, processed_feature_dict, random_seed=model_random_seed)
        logger.info(f'Prediction took {time.time() - t_0:.1f}s')
        # if this is the first go in the model_runner, then f'(includes compilation time)' would be accurate
        # Monomer?
        #  Should take about 96 secs on a 1000 residue protein using 3 recycles...

        # Remove jax dependency from results.
        np_prediction_result = jnp_to_np(dict(prediction_result))
        # logger.debug(f'Found prediction_results: {np_prediction_result}')
        # monomer
        # ['distogram', 'experimentally_resolved', 'masked_msa', 'predicted_lddt', 'structure_module', 'plddt',
        #  'ranking_confidence']
        # multimer
        # ['distogram', 'experimentally_resolved', 'masked_msa', 'predicted_lddt', 'structure_module', 'plddt',
        #  'ranking_confidence'
        #  'num_recycles', 'predicted_aligned_error', 'aligned_confidence_probs', 'max_predicted_aligned_error',
        #  'ptm', 'iptm']
        # logger.debug(f'Found the prediction_result shapes: {model_runner.eval_shape(np_prediction_result)}')
        # {'distogram': {'bin_edges': (63,), 'logits': (n_residues, n_residues, 64)},
        #  'experimentally_resolved': {'logits': (n_residues, atom_types)},
        #  'masked_msa': {'logits': (n_sequences, n_residues, n_amino_acid_types_gapped_unknown)},
        #  'predicted_aligned_error': (n_residues, n_residues),
        #  'predicted_lddt': {'logits': (n_residues, 50)},
        #  'structure_module': {'final_atom_mask': (n_residues, atom_types),
        #  'final_atom_positions': (n_residues, atom_types, 3)},
        #  'plddt': (n_residues,), 'aligned_confidence_probs': (n_residues, n_residues, 64),
        #  'max_predicted_aligned_error': (), 'ptm': (), 'iptm': (), 'ranking_confidence': (), 'num_recycles': (),
        #  }
        # Where ['predicted_lddt'] has the key ['logits'] which probably ?contains the raw logit values produced by
        # model heads? for the binned distogram rankings?

        # plddt = np_prediction_result['plddt']
        # scores[model_name]['plddt'] = plddt  # [:length]
        # if model_runner.multimer_mode:
        #     # This is a 2d array. Clean up to ASU at some point
        #     scores[model_name]['predicted_aligned_error'] = np_prediction_result['predicted_aligned_error']
        #     # scores['predicted_interface_template_modeling_score'][model_index] = np_prediction_result['iptm']
        #     scores[model_name]['predicted_interface_template_modeling_score'].append(np_prediction_result['iptm'])
        #     scores[model_name]['predicted_template_modeling_score'].append(np_prediction_result['ptm'])
        # elif 'ptm' in model_name:
        #     scores[model_name]['predicted_aligned_error'] = \
        #         np_prediction_result['predicted_aligned_error']  # [:length, :length]
        #     scores[model_name]['predicted_template_modeling_score'].append(np_prediction_result['ptm'])

        # Add the predicted LDDT in the b-factor column.
        plddt = np_prediction_result['plddt']
        # Note that higher predicted LDDT value means higher model confidence.
        plddt_b_factors = np.repeat(plddt[:, None], residue_constants.atom_type_num, axis=-1)
        unrelaxed_protein = afprotein.from_prediction(
            features=processed_feature_dict,
            result=prediction_result,
            b_factors=plddt_b_factors,
            remove_leading_feature_dimension=not model_runner.multimer_mode)
        unrelaxed_proteins[model_name] = unrelaxed_protein
        unrelaxed_pdbs_[model_name] = afprotein.to_pdb(unrelaxed_protein)

        ranking_confidences[model_name] = confidence_metric = np_prediction_result.pop('ranking_confidence')
        # Process incoming scores to be returned
        for old_score, new_score in change_scores:
            np_prediction_result[new_score] = np_prediction_result.pop(old_score)
        # Remove unnecessary scores
        for score in unneeded_scores:
            np_prediction_result.pop(score, None)
        scores[model_name] = np_prediction_result

        if model_runner.multimer_mode:
            pass
        else:  # monomer mode. Divide by 100 to get a percentage
            confidence_metric /= 100
        if confidence_metric > confidence_stop_threshold:
            # The prediction quality is already satisfactory
            break

    # Rank model names by model confidence.
    ranked_order = [design_model_name for design_model_name, confidence in
                    sorted(ranking_confidences.items(), key=lambda x: x[1], reverse=True)]
    # Sort the unrelaxed_pdbs accordingly
    unrelaxed_pdbs = {name: unrelaxed_pdbs_.pop(name) for name in ranked_order}

    # Relax predictions.
    relaxed_pdbs = {}
    # relax_metrics = {}
    if models_to_relax is None:
        pass
    else:
        if models_to_relax == 'best':
            to_relax = [ranked_order[0]]
        else:  # if models_to_relax == 'all':
            to_relax = ranked_order

        logger.info(f'Starting Amber relaxation')
        t_0 = time.time()
        for model_name in to_relax:
            logger.info(f'Relaxing {model_name}')
            # relaxed_pdb_str, _, violations = amber_relaxer.process(prot=unrelaxed_proteins[model_name])
            try:
                relaxed_pdb_str, violations = amber_relax(prot=unrelaxed_proteins[model_name], gpu=gpu_relax)
            except ValueError as error:  # Minimization failed after {max_iterations} attempts.
                logger.error(f'Ran into problem during Amber relax: {error}\nSkipping {model_name}')
                continue
            else:
                # relax_metrics[model_name] = {
                #     'remaining_violations': violations,
                #     'remaining_violations_count': sum(violations)
                # }
                relaxed_pdbs[model_name] = relaxed_pdb_str

        logger.info(f'Relaxation took {time.time() - t_0:.1f}s')

    return {'relaxed': relaxed_pdbs, 'unrelaxed': unrelaxed_pdbs}, scores
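
An end-to-end sketch, assuming 'features' was produced by the AlphaFold data pipeline for the same sequence the runners were sized for, and model_runners comes from set_up_model_runners above.

structures, folding_scores = af_predict(
    features,
    model_runners,
    models_to_relax='best',  # Relax only the top-ranked prediction
    confidence_stop_threshold=0.9,
)
# 'unrelaxed' is ordered by ranking confidence, so the first entry is the best
best_name, best_pdb = next(iter(structures['unrelaxed'].items()))
print(best_name, folding_scores[best_name]['plddt'].mean())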