Conversation

jryanhaber
Summary

Extends cross-platform compatibility by adding conditional fallbacks for
flash attention when not available on Apple Silicon MPS and other non-CUDA
systems.

Problem

The HRM codebase crashes on systems without flash_attn (primarily Apple
Silicon) due to a hard dependency, preventing broader adoption and
contributions from non-CUDA users.

Solution

Added conditional fallbacks that preserve original CUDA performance while
enabling PyTorch attention fallback on incompatible systems.

Changes

Core Compatibility (models/layers.py)

  • Conditional import with graceful flash_attn fallback
  • Runtime detection: uses flash_attn when available, PyTorch attention
    otherwise
  • Zero performance impact on CUDA systems
  • Full backward compatibility

Technical Implementation

```python
try:
    from flash_attn_interface import flash_attn_func
except ImportError:
    try:
        from flash_attn import flash_attn_func
    except ImportError:
        flash_attn_func = None
```

Runtime conditional execution:

```python
if flash_attn_func is not None:
    attn_output = flash_attn_func(q=query, k=key, v=value, causal=self.causal)
else:
    attn_output = F.scaled_dot_product_attention(query, key, value, is_causal=self.causal)
```

Compatibility Matrix

  • CUDA with flash_attn: ✅ Original performance (no changes)
  • CUDA without flash_attn: ✅ PyTorch fallback
  • Apple Silicon (MPS): ✅ PyTorch fallback
  • CPU: ✅ Universal compatibility

Testing

  • Device detection: MPS/CUDA/CPU verified
  • Model loading: 2.2GB checkpoint compatibility confirmed
  • Inference: Cross-platform execution validated
  • Performance: Zero impact on CUDA workflows

Impact

  • Broader hardware support for contributors
  • Maintained optimal performance on CUDA
  • Reduced onboarding friction
  • No breaking changes

This follows the established conditional device pattern while extending
compatibility to the attention layer.

This commit aims to enable model training and evaluation to run seamlessly on different devices (MPS, CUDA, or CPU) by introducing a device detection mechanism.
Detection takes effect when MPS or CUDA is available and otherwise defaults to CPU, since some users have no GPU access. The goal is flexibility and accessibility across hardware configurations so that no users are excluded. The commit also simplifies the code by using a single device variable throughout the training and evaluation scripts, with tests that verify device-availability assumptions and correct tensor placement.
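A helper along these lines could implement the described detection; the function name `get_device` and the CUDA-before-MPS ordering are assumptions, since the commit message does not show the actual code:

```python
import torch

def get_device() -> torch.device:
    # Prefer CUDA, then MPS (Apple Silicon), then fall back to CPU.
    if torch.cuda.is_available():
        return torch.device("cuda")
    if torch.backends.mps.is_available():
        return torch.device("mps")
    return torch.device("cpu")

device = get_device()
print(f"Using device: {device}")
```

The single `device` variable can then be passed to `model.to(device)` and tensor constructors throughout the training and evaluation scripts.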

To test this:
1. Run `test_device.py` to verify the correct device is detected and tensor operations are successful.
2. Run `python evaluate.py` and `python pretrain.py` and check the console output to confirm the intended device is being used.

Criticality: 5
Severity if code fails: 7
User impact if code fails: Model training and evaluation may not run or may run very slowly if the device is not correctly identified.
This commit aims to provide a foundational dataset for holon tags and a master tag list, enabling richer metadata and categorization within the system. This is intended to improve searchability, organization, and discoverability of holons, and only takes effect once the data is loaded and used by consuming applications. It also establishes a clear separation between holon-specific tags and the comprehensive master tag list.

To test this:
1. Verify that the `holon_tags.csv` and `tags_master.csv` files are present in the `dataset/` directory.
2. Confirm the integrity of the data within each file by examining the headers and sample rows.
3. Ensure that the tags listed in `holon_tags.csv` are a subset of those defined in `tags_master.csv`.
4. Check consuming applications properly utilize these tag datasets.
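Step 3 above can be automated with a small check; the column name `"tag"` is an assumption, since the CSV schema is not shown in the commit message:

```python
import csv

def load_tags(path, column="tag"):
    # Read one column from a CSV into a set; the column name "tag"
    # is a guess at the schema, not confirmed by the commit.
    with open(path, newline="", encoding="utf-8") as f:
        return {row[column] for row in csv.DictReader(f) if row.get(column)}

def holon_tags_are_subset(holon_path, master_path, column="tag"):
    # Every tag used by a holon must also appear in the master list.
    return load_tags(holon_path, column) <= load_tags(master_path, column)
```

Running this in CI would catch holon tags that drift out of the master taxonomy.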

Criticality: 5
Severity if code fails: 3
User impact if code fails: Reduced searchability and discoverability of holons, potentially hindering user experience in finding relevant content.
## Summary
Extends cross-platform compatibility by adding conditional fallbacks for flash attention when not available (e.g., on Apple Silicon MPS systems).

## Changes Made

### Core Compatibility (models/layers.py)
- Add conditional import for flash_attn with graceful fallback to None
- Implement PyTorch scaled_dot_product_attention fallback when flash_attn unavailable
- Preserve original flash_attn behavior when available (zero impact on CUDA systems)
- Maintain identical performance characteristics for existing workflows

### Testing Infrastructure
- Add comprehensive compatibility test suite (test_hrm_compatibility.py)
- Add holon data classification test pipeline (test_holon_classification.py)
- Test device detection, attention fallbacks, model loading, and inference

## Technical Implementation

**Flash Attention Fallback:**
```python
# Conditional import with fallback
try:
    from flash_attn_interface import flash_attn_func
except ImportError:
    try:
        from flash_attn import flash_attn_func
    except ImportError:
        flash_attn_func = None  # Fallback to PyTorch attention

# Conditional execution in attention layer
if flash_attn_func is not None:
    attn_output = flash_attn_func(q=query, k=key, v=value, causal=self.causal)
else:
    # PyTorch fallback for systems without flash_attn
    attn_output = F.scaled_dot_product_attention(query, key, value, is_causal=self.causal)
```
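One caveat worth noting about the two APIs (an observation, not something shown in the diff): `flash_attn_func` consumes tensors shaped `(batch, seq, heads, head_dim)`, while `F.scaled_dot_product_attention` expects `(batch, heads, seq, head_dim)`, so a drop-in fallback may need transposes. A sketch, assuming inputs arrive in the flash_attn layout:

```python
import torch
import torch.nn.functional as F

def sdpa_fallback(query, key, value, causal=False):
    # query/key/value arrive in flash_attn layout: (batch, seq, heads, head_dim).
    # SDPA wants (batch, heads, seq, head_dim), so transpose in and back out.
    out = F.scaled_dot_product_attention(
        query.transpose(1, 2),
        key.transpose(1, 2),
        value.transpose(1, 2),
        is_causal=causal,
    )
    return out.transpose(1, 2)

q = k = v = torch.randn(2, 16, 4, 8)  # (batch, seq, heads, head_dim)
print(sdpa_fallback(q, k, v, causal=True).shape)  # torch.Size([2, 16, 4, 8])
```

If the surrounding layer already produces SDPA-layout tensors, the transposes belong on the flash_attn branch instead; either way the two paths should agree on layout.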

## Compatibility Impact
- ✅ **CUDA systems**: Identical behavior, zero performance impact
- ✅ **MPS systems**: Now functional with PyTorch attention fallback
- ✅ **CPU systems**: Universal compatibility maintained
- ✅ **Existing checkpoints**: Full compatibility preserved

## Testing Results
- Device detection: ✅ PASS (MPS/CUDA/CPU)
- Flash attention fallback: ✅ PASS (conditional activation)
- Model loading: ✅ PASS (2.2GB checkpoint compatibility)
- Inference pipeline: ✅ PASS (cross-platform execution)

This implementation follows the established pattern of conditional device support while adding attention-layer compatibility for broader hardware adoption.
This commit aims to introduce a classification test script (`test_classification.py`) for the HRM that uses holon tags data to validate the model's ability to classify text against a defined tag taxonomy. This is crucial for ensuring the HRM system accurately categorizes and understands incoming information.

The script includes functionality to:

*   Load holon tags data (to be classified) and tags master data (classification taxonomy).
*   Prepare classification examples by combining titles and descriptions.
*   Test device compatibility using `torch`.

This implementation only takes effect if the data loads correctly from the specified CSV files and the `torch` library is installed, because the loading and device checks depend on both. The script also prints a basic summary of the data and device information for easy validation. Tests confirm that the data loads, the device is compatible, and sample data is displayed so you can check that it looks correct.
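The example-preparation step described above might look like the following sketch; the column names `"title"` and `"description"` are assumptions, since the commit only says the two fields are combined:

```python
import csv

def prepare_classification_examples(holon_path="dataset/holon_tags.csv"):
    # Combine title and description into one text string per example.
    # Column names are guesses at the schema; adjust to the real CSV headers.
    examples = []
    with open(holon_path, newline="", encoding="utf-8") as f:
        for row in csv.DictReader(f):
            parts = [row.get("title", "").strip(),
                     row.get("description", "").strip()]
            text = ": ".join(p for p in parts if p)
            if text:
                examples.append(text)
    return examples
```

Rows with both fields empty are skipped, so an empty CSV yields an empty example list, matching the note below that the CSVs may be empty.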

To test this:

1.  Ensure you have `pandas` and `torch` installed.
2.  Place `holon_tags.csv` and `tags_master.csv` in a `dataset` directory. The CSVs can be empty.
3.  Run `python test_classification.py`.

Expected output:

*   Confirmation messages about data loading.
*   Device compatibility test result.
*   Summary of classification data and a sample example.
*   Available tag categories.

If the test fails:

*   An error message indicates the point of failure (e.g., data loading, device incompatibility).
*   The script will exit with a non-zero exit code.

Criticality: 7
Severity if code fails: 5
User impact if code fails: The HRM system might fail to accurately classify incoming data, leading to incorrect categorization and potentially affecting downstream processes relying on accurate tagging.
```python
try:
    examples, tags_master = prepare_classification_examples()

    print(f"\\nClassification Test Data Summary:")
```

Why `\\n`? Maybe `\n`?

@benman1

benman1 commented Aug 16, 2025

@jryanhaber Great work, thank you! A README change would be good to make clear what the hardware support is and how to run it.
