|
3 | 3 | using Microsoft.Data.Sqlite;
|
4 | 4 | using Microsoft.Extensions.Configuration;
|
5 | 5 | using Microsoft.Extensions.DependencyInjection;
|
| 6 | +using Microsoft.Extensions.DependencyModel; |
6 | 7 | using Microsoft.Extensions.Diagnostics.HealthChecks;
|
| 8 | +using Microsoft.Extensions.Logging; |
| 9 | +using System.Data.Common; |
| 10 | +using System.Runtime.InteropServices; |
| 11 | +using System.Text.Json; |
| 12 | +using RuntimeEnvironment = Microsoft.DotNet.PlatformAbstractions.RuntimeEnvironment; |
7 | 13 |
|
8 | 14 | namespace Microsoft.Extensions.Hosting;
|
9 | 15 |
|
@@ -61,6 +67,15 @@ private static void AddSqliteClient(
|
61 | 67 | settings.ConnectionString = connectionString;
|
62 | 68 | }
|
63 | 69 |
|
| 70 | + if (!string.IsNullOrEmpty(settings.ConnectionString)) |
| 71 | + { |
| 72 | + var cbs = new DbConnectionStringBuilder { ConnectionString = settings.ConnectionString }; |
| 73 | + if (cbs.TryGetValue("Extensions", out var extensions)) |
| 74 | + { |
| 75 | + settings.Extensions = JsonSerializer.Deserialize<IEnumerable<SqliteExtensionMetadata>>((string)extensions) ?? []; |
| 76 | + } |
| 77 | + } |
| 78 | + |
64 | 79 | configureSettings?.Invoke(settings);
|
65 | 80 |
|
66 | 81 | builder.RegisterSqliteServices(settings, connectionName, serviceKey);
|
@@ -100,8 +115,181 @@ private static void RegisterSqliteServices(
|
100 | 115 |
|
101 | 116 | SqliteConnection CreateConnection(IServiceProvider sp, object? key)
|
102 | 117 | {
|
| 118 | + var logger = sp.GetRequiredService<ILogger<SqliteConnection>>(); |
103 | 119 | ConnectionStringValidation.ValidateConnectionString(settings.ConnectionString, connectionName, DefaultConfigSectionName);
|
104 |
| - return new SqliteConnection(settings.ConnectionString); |
| 120 | + var csb = new DbConnectionStringBuilder { ConnectionString = settings.ConnectionString }; |
| 121 | + if (csb.ContainsKey("Extensions")) |
| 122 | + { |
| 123 | + csb.Remove("Extensions"); |
| 124 | + } |
| 125 | + var connection = new SqliteConnection(csb.ConnectionString); |
| 126 | + |
| 127 | + foreach (var extension in settings.Extensions) |
| 128 | + { |
| 129 | + if (extension.IsNuGetPackage) |
| 130 | + { |
| 131 | + if (string.IsNullOrEmpty(extension.PackageName)) |
| 132 | + { |
| 133 | + throw new InvalidOperationException("PackageName is required when loading an extension from a NuGet package."); |
| 134 | + } |
| 135 | + |
| 136 | + EnsureLoadableFromNuGet(extension.Extension, extension.PackageName, logger); |
| 137 | + } |
| 138 | + else |
| 139 | + { |
| 140 | + if (string.IsNullOrEmpty(extension.ExtensionFolder)) |
| 141 | + { |
| 142 | + throw new InvalidOperationException("ExtensionFolder is required when loading an extension from a folder."); |
| 143 | + } |
| 144 | + |
| 145 | + EnsureLoadableFromLocalPath(extension.Extension, extension.ExtensionFolder); |
| 146 | + } |
| 147 | + connection.LoadExtension(extension.Extension); |
| 148 | + } |
| 149 | + |
| 150 | + return connection; |
| 151 | + } |
| 152 | + } |
| 153 | + |
| 154 | + // Adapted from https://github.com/dotnet/docs/blob/dbbeda13bf016a6ff76b0baab1488c927a64ff24/samples/snippets/standard/data/sqlite/ExtensionsSample/Program.cs#L40 |
| 155 | + internal static void EnsureLoadableFromNuGet(string package, string library, ILogger<SqliteConnection> logger) |
| 156 | + { |
| 157 | + var runtimeLibrary = DependencyContext.Default?.RuntimeLibraries.FirstOrDefault(l => l.Name == package); |
| 158 | + if (runtimeLibrary is null) |
| 159 | + { |
| 160 | + logger.LogInformation("Could not find the runtime library for package {Package}", package); |
| 161 | + return; |
| 162 | + } |
| 163 | + |
| 164 | + string sharedLibraryExtension; |
| 165 | + string pathVariableName = "PATH"; |
| 166 | + if (RuntimeInformation.IsOSPlatform(OSPlatform.Windows)) |
| 167 | + { |
| 168 | + sharedLibraryExtension = ".dll"; |
| 169 | + } |
| 170 | + else if (RuntimeInformation.IsOSPlatform(OSPlatform.Linux)) |
| 171 | + { |
| 172 | + sharedLibraryExtension = ".so"; |
| 173 | + pathVariableName = "LD_LIBRARY_PATH"; |
| 174 | + } |
| 175 | + else if (RuntimeInformation.IsOSPlatform(OSPlatform.OSX)) |
| 176 | + { |
| 177 | + sharedLibraryExtension = ".dylib"; |
| 178 | + pathVariableName = "DYLD_LIBRARY_PATH"; |
| 179 | + } |
| 180 | + else |
| 181 | + { |
| 182 | + throw new NotSupportedException("Unsupported OS platform"); |
| 183 | + } |
| 184 | + |
| 185 | + var candidateAssets = new Dictionary<(string? Package, string Asset), int>(); |
| 186 | + var rid = RuntimeEnvironment.GetRuntimeIdentifier(); |
| 187 | + var rids = DependencyContext.Default?.RuntimeGraph.First(g => g.Runtime == rid).Fallbacks.ToList() ?? []; |
| 188 | + rids.Insert(0, rid); |
| 189 | + |
| 190 | + logger.LogInformation("Looking for {Library} in {Package} runtime assets", library, package); |
| 191 | + logger.LogInformation("Possible runtime identifiers: {Rids}", string.Join(", ", rids)); |
| 192 | + |
| 193 | + foreach (var group in runtimeLibrary.NativeLibraryGroups) |
| 194 | + { |
| 195 | + foreach (var file in group.RuntimeFiles) |
| 196 | + { |
| 197 | + if (string.Equals( |
| 198 | + Path.GetFileName(file.Path), |
| 199 | + library + sharedLibraryExtension, |
| 200 | + StringComparison.OrdinalIgnoreCase)) |
| 201 | + { |
| 202 | + var fallbacks = rids.IndexOf(group.Runtime); |
| 203 | + if (fallbacks != -1) |
| 204 | + { |
| 205 | + logger.LogInformation("Found {Library} in {Package} runtime assets at {Path}", library, package, file.Path); |
| 206 | + candidateAssets.Add((runtimeLibrary.Path, file.Path), fallbacks); |
| 207 | + } |
| 208 | + } |
| 209 | + } |
| 210 | + } |
| 211 | + |
| 212 | + var assetPath = candidateAssets |
| 213 | + .OrderBy(p => p.Value) |
| 214 | + .Select(p => p.Key) |
| 215 | + .FirstOrDefault(); |
| 216 | + if (assetPath != default) |
| 217 | + { |
| 218 | + string? assetDirectory = null; |
| 219 | + if (File.Exists(Path.Combine(AppContext.BaseDirectory, assetPath.Asset))) |
| 220 | + { |
| 221 | + // NB: Framework-dependent deployments copy assets to the application base directory |
| 222 | + assetDirectory = Path.Combine( |
| 223 | + AppContext.BaseDirectory, |
| 224 | + Path.GetDirectoryName(assetPath.Asset.Replace('/', Path.DirectorySeparatorChar))!); |
| 225 | + |
| 226 | + logger.LogInformation("Found {Library} in {Package} runtime assets at {Path}", library, package, assetPath.Asset); |
| 227 | + } |
| 228 | + else |
| 229 | + { |
| 230 | + string? assetFullPath = null; |
| 231 | + var probingDirectories = ((string?)AppDomain.CurrentDomain.GetData("PROBING_DIRECTORIES"))? |
| 232 | + .Split(Path.PathSeparator) ?? []; |
| 233 | + foreach (var directory in probingDirectories) |
| 234 | + { |
| 235 | + var candidateFullPath = Path.Combine( |
| 236 | + directory, |
| 237 | + assetPath.Package ?? "", |
| 238 | + assetPath.Asset); |
| 239 | + if (File.Exists(candidateFullPath)) |
| 240 | + { |
| 241 | + assetFullPath = candidateFullPath; |
| 242 | + } |
| 243 | + } |
| 244 | + |
| 245 | + assetDirectory = Path.GetDirectoryName(assetFullPath); |
| 246 | + logger.LogInformation("Found {Library} in {Package} runtime assets at {Path} (using PROBING_DIRECTORIES: {ProbingDirectories})", library, package, assetFullPath, string.Join(",", probingDirectories)); |
| 247 | + } |
| 248 | + |
| 249 | + var path = new HashSet<string>(Environment.GetEnvironmentVariable(pathVariableName)!.Split(Path.PathSeparator)); |
| 250 | + |
| 251 | + if (assetDirectory is not null && path.Add(assetDirectory)) |
| 252 | + { |
| 253 | + logger.LogInformation("Adding {AssetDirectory} to {PathVariableName}", assetDirectory, pathVariableName); |
| 254 | + Environment.SetEnvironmentVariable(pathVariableName, string.Join(Path.PathSeparator, path)); |
| 255 | + logger.LogInformation("Set {PathVariableName} to: {PathVariableValue}", pathVariableName, Environment.GetEnvironmentVariable(pathVariableName)); |
| 256 | + } |
| 257 | + } |
| 258 | + else |
| 259 | + { |
| 260 | + logger.LogInformation("Could not find {Library} in {Package} runtime assets", library, package); |
| 261 | + } |
| 262 | + } |
| 263 | + |
| 264 | + internal static void EnsureLoadableFromLocalPath(string library, string assetDirectory) |
| 265 | + { |
| 266 | + string sharedLibraryExtension; |
| 267 | + string pathVariableName = "PATH"; |
| 268 | + if (RuntimeInformation.IsOSPlatform(OSPlatform.Windows)) |
| 269 | + { |
| 270 | + sharedLibraryExtension = ".dll"; |
| 271 | + } |
| 272 | + else if (RuntimeInformation.IsOSPlatform(OSPlatform.Linux)) |
| 273 | + { |
| 274 | + sharedLibraryExtension = ".so"; |
| 275 | + pathVariableName = "LD_LIBRARY_PATH"; |
| 276 | + } |
| 277 | + else if (RuntimeInformation.IsOSPlatform(OSPlatform.OSX)) |
| 278 | + { |
| 279 | + sharedLibraryExtension = ".dylib"; |
| 280 | + pathVariableName = "DYLD_LIBRARY_PATH"; |
| 281 | + } |
| 282 | + else |
| 283 | + { |
| 284 | + throw new NotSupportedException("Unsupported OS platform"); |
| 285 | + } |
| 286 | + |
| 287 | + if (File.Exists(Path.Combine(assetDirectory, library + sharedLibraryExtension))) |
| 288 | + { |
| 289 | + var path = new HashSet<string>(Environment.GetEnvironmentVariable(pathVariableName)!.Split(Path.PathSeparator)); |
| 290 | + |
| 291 | + if (assetDirectory is not null && path.Add(assetDirectory)) |
| 292 | + Environment.SetEnvironmentVariable(pathVariableName, string.Join(Path.PathSeparator, path)); |
105 | 293 | }
|
106 | 294 | }
|
107 | 295 | }
|
0 commit comments