Why Gemfury? Push, build, and install  RubyGems npm packages Python packages Maven artifacts PHP packages Go Modules Debian packages RPM packages NuGet packages

Repository URL to install this package:

Details    
golang / usr / local / go / src / os / writeto_linux_test.go
Size: Mime:
// Copyright 2023 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

package os_test

import (
	"bytes"
	"internal/poll"
	"io"
	"math/rand"
	"net"
	. "os"
	"strconv"
	"syscall"
	"testing"
	"time"
)

func TestSendFile(t *testing.T) {
	sizes := []int{
		1,
		42,
		1025,
		syscall.Getpagesize() + 1,
		32769,
	}
	t.Run("sendfile-to-unix", func(t *testing.T) {
		for _, size := range sizes {
			t.Run(strconv.Itoa(size), func(t *testing.T) {
				testSendFile(t, "unix", int64(size))
			})
		}
	})
	t.Run("sendfile-to-tcp", func(t *testing.T) {
		for _, size := range sizes {
			t.Run(strconv.Itoa(size), func(t *testing.T) {
				testSendFile(t, "tcp", int64(size))
			})
		}
	})
}

func testSendFile(t *testing.T, proto string, size int64) {
	dst, src, recv, data, hook := newSendFileTest(t, proto, size)

	// Now call WriteTo (through io.Copy), which will hopefully call poll.SendFile
	n, err := io.Copy(dst, src)
	if err != nil {
		t.Fatalf("io.Copy error: %v", err)
	}

	// We should have called poll.Splice with the right file descriptor arguments.
	if n > 0 && !hook.called {
		t.Fatal("expected to called poll.SendFile")
	}
	if hook.called && hook.srcfd != int(src.Fd()) {
		t.Fatalf("wrong source file descriptor: got %d, want %d", hook.srcfd, src.Fd())
	}
	sc, ok := dst.(syscall.Conn)
	if !ok {
		t.Fatalf("destination is not a syscall.Conn")
	}
	rc, err := sc.SyscallConn()
	if err != nil {
		t.Fatalf("destination SyscallConn error: %v", err)
	}
	if err = rc.Control(func(fd uintptr) {
		if hook.called && hook.dstfd != int(fd) {
			t.Fatalf("wrong destination file descriptor: got %d, want %d", hook.dstfd, int(fd))
		}
	}); err != nil {
		t.Fatalf("destination Conn Control error: %v", err)
	}

	// Verify the data size and content.
	dataSize := len(data)
	dstData := make([]byte, dataSize)
	m, err := io.ReadFull(recv, dstData)
	if err != nil {
		t.Fatalf("server Conn Read error: %v", err)
	}
	if n != int64(dataSize) {
		t.Fatalf("data length mismatch for io.Copy, got %d, want %d", n, dataSize)
	}
	if m != dataSize {
		t.Fatalf("data length mismatch for net.Conn.Read, got %d, want %d", m, dataSize)
	}
	if !bytes.Equal(dstData, data) {
		t.Errorf("data mismatch, got %s, want %s", dstData, data)
	}
}

// newSendFileTest initializes a new test for sendfile.
//
// It creates source file and destination sockets, and populates the source file
// with random data of the specified size. It also hooks package os' call
// to poll.Sendfile and returns the hook so it can be inspected.
func newSendFileTest(t *testing.T, proto string, size int64) (net.Conn, *File, net.Conn, []byte, *sendFileHook) {
	t.Helper()

	hook := hookSendFile(t)

	client, server := createSocketPair(t, proto)
	tempFile, data := createTempFile(t, size)

	return client, tempFile, server, data, hook
}

func hookSendFile(t *testing.T) *sendFileHook {
	h := new(sendFileHook)
	h.install()
	t.Cleanup(h.uninstall)
	return h
}

type sendFileHook struct {
	called bool
	dstfd  int
	srcfd  int
	remain int64

	written int64
	handled bool
	err     error

	original func(dst *poll.FD, src int, remain int64) (int64, error, bool)
}

func (h *sendFileHook) install() {
	h.original = *PollSendFile
	*PollSendFile = func(dst *poll.FD, src int, remain int64) (int64, error, bool) {
		h.called = true
		h.dstfd = dst.Sysfd
		h.srcfd = src
		h.remain = remain
		h.written, h.err, h.handled = h.original(dst, src, remain)
		return h.written, h.err, h.handled
	}
}

func (h *sendFileHook) uninstall() {
	*PollSendFile = h.original
}

func createTempFile(t *testing.T, size int64) (*File, []byte) {
	f, err := CreateTemp(t.TempDir(), "writeto-sendfile-to-socket")
	if err != nil {
		t.Fatalf("failed to create temporary file: %v", err)
	}
	t.Cleanup(func() {
		f.Close()
	})

	randSeed := time.Now().Unix()
	t.Logf("random data seed: %d\n", randSeed)
	prng := rand.New(rand.NewSource(randSeed))
	data := make([]byte, size)
	prng.Read(data)
	if _, err := f.Write(data); err != nil {
		t.Fatalf("failed to create and feed the file: %v", err)
	}
	if err := f.Sync(); err != nil {
		t.Fatalf("failed to save the file: %v", err)
	}
	if _, err := f.Seek(0, io.SeekStart); err != nil {
		t.Fatalf("failed to rewind the file: %v", err)
	}

	return f, data
}