diff --git a/src/Renci.SshNet/SftpClient.cs b/src/Renci.SshNet/SftpClient.cs index 960c6261f..2837153e2 100644 --- a/src/Renci.SshNet/SftpClient.cs +++ b/src/Renci.SshNet/SftpClient.cs @@ -905,7 +905,7 @@ public void DownloadFile(string path, Stream output, Action? downloadCall if (downloadCallback != null) { - downloadProgress = new Progress(r => downloadCallback(r.TotalBytesDownloaded)); + downloadProgress = new SynchronousProgress(r => downloadCallback(r.TotalBytesDownloaded)); } InternalDownloadFile( @@ -934,7 +934,7 @@ public Task DownloadFileAsync(string path, Stream output, IProgress(r => downloadCallback(r.TotalBytesDownloaded)); + // The System.Progress ctor captures the current synchronization context + // and posts the progress reports to it. For back-compat with previous + // versions which always posted the callback to the threadpool regardless of + // sync context, we use a custom IProgress impl. + downloadProgress = new ThreadPoolProgress(r => downloadCallback(r.TotalBytesDownloaded)); } var asyncResult = new SftpDownloadAsyncResult(asyncCallback, state); @@ -1089,7 +1093,7 @@ public void UploadFile(Stream input, string path, bool canOverride, Action(r => uploadCallback(r.TotalBytesUploaded)); + uploadProgress = new SynchronousProgress(r => uploadCallback(r.TotalBytesUploaded)); } InternalUploadFile( @@ -1273,7 +1277,11 @@ public IAsyncResult BeginUploadFile(Stream input, string path, bool canOverride, if (uploadCallback != null) { - uploadProgress = new Progress(r => uploadCallback(r.TotalBytesUploaded)); + // The System.Progress ctor captures the current synchronization context + // and posts the progress reports to it. For back-compat with previous + // versions which always posted the callback to the threadpool regardless of + // sync context, we use a custom IProgress impl. + uploadProgress = new ThreadPoolProgress(r => uploadCallback(r.TotalBytesUploaded)); } var asyncResult = new SftpUploadAsyncResult(asyncCallback, state); @@ -2417,16 +2425,10 @@ private async Task InternalDownloadFile( asyncResult?.Update(totalBytesRead); - if (downloadProgress is not null) + downloadProgress?.Report(new DownloadFileProgressReport() { - // Copy offset to ensure it's not modified between now and execution of callback - var report = new DownloadFileProgressReport() - { - TotalBytesDownloaded = totalBytesRead, - }; - - downloadProgress.Report(report); - } + TotalBytesDownloaded = totalBytesRead + }); } } finally @@ -2536,16 +2538,10 @@ private async Task InternalUploadFile( asyncResult?.Update(writtenBytes); - // Call callback to report number of bytes written - if (uploadProgress is not null) + uploadProgress?.Report(new UploadFileProgressReport() { - UploadFileProgressReport report = new() - { - TotalBytesUploaded = writtenBytes, - }; - - uploadProgress.Report(report); - } + TotalBytesUploaded = writtenBytes + }); } finally { @@ -2652,5 +2648,48 @@ private ISftpSession CreateAndConnectToSftpSession() throw; } } + + /// + /// An implementation that posts callbacks to the threadpool. + /// + private sealed class ThreadPoolProgress : IProgress + { + private readonly Action _handler; + + public ThreadPoolProgress(Action handler) + { + Debug.Assert(handler != null); + _handler = handler!; + } + + void IProgress.Report(T value) + { + _ = ThreadPool.QueueUserWorkItem(static state => + { + var (handler, value) = ((Action, T))state!; + handler(value); + }, + (_handler, value)); + } + } + + /// + /// An implementation that invokes callbacks synchronously. + /// + private sealed class SynchronousProgress : IProgress + { + private readonly Action _handler; + + public SynchronousProgress(Action handler) + { + Debug.Assert(handler != null); + _handler = handler!; + } + + void IProgress.Report(T value) + { + _handler.Invoke(value); + } + } } }