-
Notifications
You must be signed in to change notification settings - Fork 13
Updating Healpix CUDA primitive #290
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Open
ASKabalan
wants to merge
23
commits into
main
Choose a base branch
from
ASKabalan
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Open
Changes from 9 commits
Commits
Show all changes
23 commits
Select commit
Hold shift + click to select a range
2fd3c8a
Update JAX Binding to use FFI
ASKabalan 2b591ca
Update JAX Primitive to accept is_linear
ASKabalan 8fe86c2
Update healpix_ffts to use new FFI lowered cuda healpix ffts
ASKabalan 933ac2a
Update benchmarks
ASKabalan e2cc68c
Update Pyproject.toml and build to include FFI headers
ASKabalan b5cbeac
Implement VMAP and transpose rules for cuda primitive
ASKabalan 9e0f121
Update JAX binding layer
ASKabalan 92fe6a0
add vmap jacrev and jacfwd tests
ASKabalan a70b262
Fix build without CUDA NVCC
ASKabalan 0e03787
Implement requested changes
ASKabalan 6f6c07e
Update tests/test_healpix_ffts.py
ASKabalan f8a9a6d
Merge remote-tracking branch 'origin/main' into ASKabalan
ASKabalan 866d1f2
don't include ffi headers if cuda is not available
ASKabalan a83dbd1
Fix memory illegal access issue
ASKabalan fd7860e
remove strict requirement on JAX being less than 0.6.0
ASKabalan d29af9b
format
ASKabalan 1ac3541
removubg s2fft callbacks
ASKabalan b75c0ce
code works
ASKabalan 00b169c
Updating CUDA extension and removing CUFFT callbacks
ASKabalan 9775bba
remvove callback params workspace
ASKabalan fb8d0df
format
ASKabalan 1c1361f
Update CMakeLists.txt
ASKabalan c27dc7e
Update CMakeLists.txt
ASKabalan File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,165 @@ | ||
|
||
/** | ||
* @file cudastreamhandler.hpp | ||
* @brief Singleton class for managing CUDA streams and events. | ||
* | ||
* This header provides a singleton implementation that encapsulates the creation, | ||
* management, and cleanup of CUDA streams and events. It offers functions to fork | ||
* streams, add new streams, and synchronize (join) streams with a given dependency. | ||
* | ||
* Usage example: | ||
* @code | ||
* #include "cudastreamhandler.hpp" | ||
* | ||
* int main() { | ||
* // Create a handler instance | ||
* CudaStreamHandler handler; | ||
* | ||
* // Fork 4 streams dependent on a given stream 'stream_main' | ||
* handler.Fork(stream_main, 4); | ||
* | ||
* // Do work on the forked streams... | ||
* | ||
* // Join the streams back to 'stream_main' | ||
* handler.join(stream_main); | ||
* | ||
* return 0; | ||
* } | ||
* @endcode | ||
* | ||
* Author: Wassim KABALAN | ||
*/ | ||
|
||
#ifndef CUDASTREAMHANDLER_HPP | ||
#define CUDASTREAMHANDLER_HPP | ||
|
||
#include <algorithm> | ||
#include <atomic> | ||
#include <cuda_runtime.h> | ||
#include <stdexcept> | ||
#include <thread> | ||
#include <vector> | ||
|
||
// Singleton class managing CUDA streams and events | ||
class CudaStreamHandlerImpl { | ||
public: | ||
static CudaStreamHandlerImpl &instance() { | ||
static CudaStreamHandlerImpl instance; | ||
return instance; | ||
} | ||
|
||
void AddStreams(int numStreams) { | ||
if (numStreams > m_streams.size()) { | ||
int streamsToAdd = numStreams - m_streams.size(); | ||
m_streams.resize(numStreams); | ||
std::generate(m_streams.end() - streamsToAdd, m_streams.end(), []() { | ||
cudaStream_t stream; | ||
cudaStreamCreate(&stream); | ||
return stream; | ||
}); | ||
} | ||
} | ||
|
||
void join(cudaStream_t finalStream) { | ||
std::for_each(m_streams.begin(), m_streams.end(), [this, finalStream](cudaStream_t stream) { | ||
cudaEvent_t event; | ||
cudaEventCreate(&event); | ||
cudaEventRecord(event, stream); | ||
cudaStreamWaitEvent(finalStream, event, 0); | ||
m_events.push_back(event); | ||
}); | ||
|
||
if (!cleanup_thread.joinable()) { | ||
stop_thread.store(false); | ||
cleanup_thread = std::thread([this]() { this->AsyncEventCleanup(); }); | ||
} | ||
} | ||
|
||
// Fork function to add streams and set dependency on a given stream | ||
void Fork(cudaStream_t dependentStream, int N) { | ||
AddStreams(N); // Add N streams | ||
|
||
// Set dependency on the provided stream | ||
std::for_each(m_streams.end() - N, m_streams.end(), [this, dependentStream](cudaStream_t stream) { | ||
cudaEvent_t event; | ||
cudaEventCreate(&event); | ||
cudaEventRecord(event, dependentStream); | ||
cudaStreamWaitEvent(stream, event, 0); // Set the stream to wait on the event | ||
m_events.push_back(event); | ||
}); | ||
} | ||
|
||
auto getIterator() { return StreamIterator(m_streams.begin(), m_streams.end()); } | ||
|
||
~CudaStreamHandlerImpl() { | ||
stop_thread.store(true); | ||
if (cleanup_thread.joinable()) { | ||
cleanup_thread.join(); | ||
} | ||
|
||
std::for_each(m_streams.begin(), m_streams.end(), cudaStreamDestroy); | ||
std::for_each(m_events.begin(), m_events.end(), cudaEventDestroy); | ||
} | ||
|
||
// Custom Iterator class to iterate over streams | ||
class StreamIterator { | ||
public: | ||
StreamIterator(std::vector<cudaStream_t>::iterator begin, std::vector<cudaStream_t>::iterator end) | ||
: current(begin), end(end) {} | ||
|
||
cudaStream_t next() { | ||
if (current == end) { | ||
throw std::out_of_range("No more streams."); | ||
} | ||
return *current++; | ||
} | ||
|
||
bool hasNext() const { return current != end; } | ||
|
||
private: | ||
std::vector<cudaStream_t>::iterator current; | ||
std::vector<cudaStream_t>::iterator end; | ||
}; | ||
|
||
private: | ||
CudaStreamHandlerImpl() : stop_thread(false) {} | ||
CudaStreamHandlerImpl(const CudaStreamHandlerImpl &) = delete; | ||
CudaStreamHandlerImpl &operator=(const CudaStreamHandlerImpl &) = delete; | ||
|
||
void AsyncEventCleanup() { | ||
while (!stop_thread.load()) { | ||
std::for_each(m_events.begin(), m_events.end(), [this](cudaEvent_t &event) { | ||
if (cudaEventQuery(event) == cudaSuccess) { | ||
cudaEventDestroy(event); | ||
event = nullptr; | ||
} | ||
}); | ||
std::this_thread::sleep_for(std::chrono::milliseconds(10)); | ||
} | ||
} | ||
|
||
std::vector<cudaStream_t> m_streams; | ||
std::vector<cudaEvent_t> m_events; | ||
std::thread cleanup_thread; | ||
std::atomic<bool> stop_thread; | ||
}; | ||
|
||
// Public class for encapsulating the singleton operations | ||
class CudaStreamHandler { | ||
public: | ||
CudaStreamHandler() = default; | ||
~CudaStreamHandler() = default; | ||
|
||
void AddStreams(int numStreams) { CudaStreamHandlerImpl::instance().AddStreams(numStreams); } | ||
|
||
void join(cudaStream_t finalStream) { CudaStreamHandlerImpl::instance().join(finalStream); } | ||
|
||
void Fork(cudaStream_t cudastream, int N) { CudaStreamHandlerImpl::instance().Fork(cudastream, N); } | ||
|
||
// Get the custom iterator for CUDA streams | ||
CudaStreamHandlerImpl::StreamIterator getIterator() { | ||
return CudaStreamHandlerImpl::instance().getIterator(); | ||
} | ||
}; | ||
|
||
#endif // CUDASTREAMHANDLER_HPP |
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. With this file removed we can remove the comment in README Lines 350 to 352 in d77e9cb
|
This file was deleted.
Oops, something went wrong.
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.