Skip to content

alpha-CROWN: `compute_bounds` fails on broadcasting operators #95

@elfman2

Description

@elfman2

Describe the bug
The test model below fails with the alpha-CROWN method but passes with CROWN.
The model class below shows four ways to perform a broadcast (three of them commented out); all four fail.

To Reproduce

import torch
from auto_LiRPA import BoundedModule, BoundedTensor
from auto_LiRPA.perturbations import PerturbationLpNorm
class MyModel(torch.nn.Module):
    """Minimal model reproducing the alpha-CROWN broadcasting failure.

    With the (1, 1, 2) example input used in this script, ``forward`` produces
    a (1, 2, 2) output through a broadcasting operation. The active line uses
    transpose + matmul. The commented-out lines are alternative broadcasts that
    were reported to fail as well. To try one, uncomment it in place of the
    active line, one at a time.
    """
    def __init__(self):
        super().__init__()
    def forward(self,x):
        """Broadcast ``x`` (shape (1, 1, 2) in this script) to shape (1, 2, 2)."""
        # Alternative broadcasts, each reported to fail under alpha-CROWN too.
        # NOTE: Tensor.repeat returns a new tensor, so its result must be
        # assigned back to x; a bare `x.repeat(1,2,1)` would be a no-op.
#        x = x + torch.zeros(1,2,1)
#        x = x.repeat(1,2,1)
#        x = x*torch.tensor([[[1.,0.],[1.,1.]]])
        # (1, 1, 2) -transpose-> (1, 2, 1); matmul with (1, 1, 2) -> (1, 2, 2)
        x = x.transpose(2,1).matmul(torch.full((1,1,2),1.))
        return x
# Wrap the model for bound computation. The second argument is the example
# input used to trace MyModel (shape (1, 1, 2)).
bm = BoundedModule(MyModel(), torch.empty(1, 1, 2))

# L-inf box perturbation: every input element ranges over [0, 1].
x_L = torch.full((1, 1, 2), 0.)
x_U = torch.full((1, 1, 2), 1.)
# Use the box midpoint as the nominal input rather than torch.empty(). The
# contents of torch.empty() are uninitialized memory (possibly NaN/inf or
# outside [x_L, x_U]), which would make the reproduction nondeterministic.
x_center = (x_L + x_U) / 2
bt = BoundedTensor(x_center,
                   PerturbationLpNorm(norm=float("inf"), x_L=x_L, x_U=x_U))

# Reported behavior: method='CROWN' succeeds on this model, while
# method='alpha-CROWN' raises IndexError inside _update_best_ret.
bm.compute_bounds(x=bt, method='alpha-CROWN')
---------------------------------------------------------------------------
IndexError                                Traceback (most recent call last)
Cell In[19], line 16
     11 bm = BoundedModule(MyModel(), torch.empty(1,1,2))
     12 bt = BoundedTensor(torch.empty((1,1,2)), 
     13                          PerturbationLpNorm(norm = float("inf"), 
     14                                             x_L=torch.full((1,1,2),0.), 
     15                                             x_U=torch.full((1,1,2),1.)))
---> 16 bm.compute_bounds(x=bt, method='alpha-CROWN')

File /opt/app-root/lib64/python3.11/site-packages/auto_LiRPA/bound_general.py:1384, in BoundedModule.compute_bounds(self, x, aux, C, method, IBP, forward, bound_lower, bound_upper, reuse_ibp, reuse_alpha, return_A, needed_A_dict, final_node_name, average_A, interm_bounds, reference_bounds, intermediate_constr, alpha_idx, aux_reference_bounds, need_A_only, cutter, decision_thresh, update_mask, ibp_nodes, cache_bounds)
   1377 kwargs = dict(x=x, C=C, method=method, interm_bounds=interm_bounds,
   1378     reference_bounds=reference_bounds, return_A=return_A,
   1379     aux_reference_bounds=aux_reference_bounds,
   1380     needed_A_dict=needed_A_dict,
   1381     final_node_name=final_node_name,
   1382     cutter=cutter, decision_thresh=decision_thresh)
   1383 if bound_upper:
-> 1384     ret2 = self._get_optimized_bounds(bound_side='upper', **kwargs)
   1385 else:
   1386     ret2 = None

File /opt/app-root/lib64/python3.11/site-packages/auto_LiRPA/optimized_bounds.py:622, in _get_optimized_bounds(self, x, aux, C, IBP, forward, method, bound_side, reuse_ibp, return_A, average_A, final_node_name, interm_bounds, reference_bounds, aux_reference_bounds, needed_A_dict, cutter, decision_thresh, epsilon_over_decision_thresh)
    620 if keep_best:
    621     if best_ret_u is not None:
--> 622         best_ret_u, best_ret, need_update, idx_mask, improved_idx = _update_best_ret(
    623             full_ret_u, best_ret_u, full_ret, best_ret, need_update,
    624             loss_reduction_func, idx=1, deterministic=deterministic)
    625     if best_ret_l is not None:
    626         best_ret_l, best_ret, need_update, idx_mask, improved_idx = _update_best_ret(
    627             full_ret_l, best_ret_l, full_ret, best_ret, need_update,
    628             loss_reduction_func, idx=0, deterministic=deterministic)

File /opt/app-root/lib64/python3.11/site-packages/auto_LiRPA/optimized_bounds.py:231, in _update_best_ret(full_ret_bound, best_ret_bound, full_ret, best_ret, need_update, loss_reduction_func, idx, deterministic)
    228 compare = torch.max if idx == 0 else torch.min
    229 if not deterministic:
    230     best_ret_bound[improved_idx] = compare(
--> 231         full_ret_bound[improved_idx], best_ret_bound[improved_idx])
    232 else:
    233     best_ret_bound[improved_idx] = full_ret_bound[improved_idx]

IndexError: index 1 is out of bounds for dimension 0 with size 1

System configuration:

  • OS: Red Hat
  • Python version: 3.11
  • PyTorch version: 2.3.1
  • Hardware: CPU
  • Yes

Metadata

Metadata

Assignees

No one assigned

    Labels

    No labels
    No labels

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions