diff --git a/spring-core/src/main/java/org/springframework/core/io/buffer/DataBufferUtils.java b/spring-core/src/main/java/org/springframework/core/io/buffer/DataBufferUtils.java index 6c12d4ed3d86..bf6a56f99f7a 100644 --- a/spring-core/src/main/java/org/springframework/core/io/buffer/DataBufferUtils.java +++ b/spring-core/src/main/java/org/springframework/core/io/buffer/DataBufferUtils.java @@ -1138,9 +1138,11 @@ protected void hookOnSubscribe(Subscription subscription) { protected void hookOnNext(DataBuffer dataBuffer) { try { try (DataBuffer.ByteBufferIterator iterator = dataBuffer.readableByteBuffers()) { - ByteBuffer byteBuffer = iterator.next(); - while (byteBuffer.hasRemaining()) { - this.channel.write(byteBuffer); + while (iterator.hasNext()) { + ByteBuffer byteBuffer = iterator.next(); + while (byteBuffer.hasRemaining()) { + this.channel.write(byteBuffer); + } } } this.sink.next(dataBuffer); @@ -1213,6 +1215,11 @@ protected void hookOnNext(DataBuffer dataBuffer) { failed(ex, attachment); } } + else { + iterator.close(); + this.sink.next(dataBuffer); + request(1); + } } @Override @@ -1236,7 +1243,6 @@ protected void hookOnComplete() { @Override public void completed(Integer written, Attachment attachment) { DataBuffer.ByteBufferIterator iterator = attachment.iterator(); - iterator.close(); long pos = this.position.addAndGet(written); ByteBuffer byteBuffer = attachment.byteBuffer(); @@ -1246,9 +1252,11 @@ public void completed(Integer written, Attachment attachment) { } else if (iterator.hasNext()) { ByteBuffer next = iterator.next(); - this.channel.write(next, pos, attachment, this); + Attachment nextAttachment = new Attachment(next, attachment.dataBuffer(), iterator); + this.channel.write(next, pos, nextAttachment, this); } else { + iterator.close(); this.sink.next(attachment.dataBuffer()); this.writing.set(false); diff --git a/spring-core/src/test/java/org/springframework/core/io/buffer/DataBufferUtilsTests.java b/spring-core/src/test/java/org/springframework/core/io/buffer/DataBufferUtilsTests.java index 00d95af03a18..6c0be595624a 100644 --- a/spring-core/src/test/java/org/springframework/core/io/buffer/DataBufferUtilsTests.java +++ b/spring-core/src/test/java/org/springframework/core/io/buffer/DataBufferUtilsTests.java @@ -338,6 +338,27 @@ void writeWritableByteChannel(DataBufferFactory bufferFactory) throws Exception channel.close(); } + @ParameterizedDataBufferAllocatingTest + void writeWritableByteChannelWithJoinedBuffer(DataBufferFactory bufferFactory) throws Exception { + super.bufferFactory = bufferFactory; + + DataBuffer foo = stringBuffer("foo"); + DataBuffer bar = stringBuffer("bar"); + DataBuffer joined = bufferFactory.join(List.of(foo, bar)); + + WritableByteChannel channel = Files.newByteChannel(tempFile, StandardOpenOption.WRITE); + + Flux writeResult = DataBufferUtils.write(Flux.just(joined), channel); + StepVerifier.create(writeResult) + .consumeNextWith(stringConsumer("foobar")) + .verifyComplete(); + + String result = String.join("", Files.readAllLines(tempFile)); + + assertThat(result).isEqualTo("foobar"); + channel.close(); + } + @ParameterizedDataBufferAllocatingTest void writeWritableByteChannelErrorInFlux(DataBufferFactory bufferFactory) throws Exception { super.bufferFactory = bufferFactory; @@ -445,6 +466,27 @@ private void verifyWrittenData(Flux writeResult) throws IOException assertThat(result).isEqualTo("foobarbazqux"); } + @ParameterizedDataBufferAllocatingTest + void writeAsynchronousFileChannelWithJoinedBuffer(DataBufferFactory bufferFactory) throws Exception { + super.bufferFactory = bufferFactory; + + DataBuffer foo = stringBuffer("foo"); + DataBuffer bar = stringBuffer("bar"); + DataBuffer joined = bufferFactory.join(List.of(foo, bar)); + + AsynchronousFileChannel channel = AsynchronousFileChannel.open(tempFile, StandardOpenOption.WRITE); + + Flux writeResult = DataBufferUtils.write(Flux.just(joined), channel); + StepVerifier.create(writeResult) + .consumeNextWith(stringConsumer("foobar")) + .verifyComplete(); + + String result = String.join("", Files.readAllLines(tempFile)); + + assertThat(result).isEqualTo("foobar"); + channel.close(); + } + @ParameterizedDataBufferAllocatingTest void writeAsynchronousFileChannelErrorInFlux(DataBufferFactory bufferFactory) throws Exception { super.bufferFactory = bufferFactory;