300 lines
7.2 KiB
Go
300 lines
7.2 KiB
Go
package main
|
|
|
|
import (
|
|
"bytes"
|
|
"context"
|
|
"encoding/binary"
|
|
"fmt"
|
|
"log"
|
|
"os"
|
|
"strings"
|
|
"sync"
|
|
"time"
|
|
|
|
"github.com/cilium/ebpf"
|
|
"github.com/cilium/ebpf/link"
|
|
"github.com/cilium/ebpf/perf"
|
|
"github.com/shirou/gopsutil/process"
|
|
)
|
|
|
|
// eBPFEvent is the Go-side layout of a single sample read from the eBPF perf
// buffer. Samples are decoded field by field with encoding/binary in
// little-endian order, so the order and size of every field below is part of
// the wire format and must not change.
type eBPFEvent struct {
	// Pid and Tid identify the task that triggered the event.
	// NOTE(review): presumably tgid and thread id as seen by the eBPF
	// program; confirm against the C source.
	Pid, Tid uint32
	// Op is 'R' or 'W' and Stage is 'E' or 'X'.
	// NOTE(review): presumably read/write and entry/exit, matching the
	// vfs_read/vfs_write kprobe and kretprobe pairs; confirm.
	Op, Stage uint8
	// Pad brings the preceding fields up to 16 bytes so that Bytes starts
	// on an 8-byte boundary.
	Pad [6]byte
	// Bytes is the byte count reported for the operation.
	Bytes uint64
	// Comm, Directory and Filename are fixed-size byte buffers.
	// NOTE(review): presumably NUL-terminated C strings; trim at the first
	// zero byte before converting to string.
	Comm      [32]byte
	Directory [128]byte
	Filename  [128]byte
}
|
|
|
|
func (m *Monitor) ListenToeBPFEvents() {
|
|
objs := io_traceObjects{}
|
|
if err := loadIo_traceObjects(&objs, nil); err != nil {
|
|
log.Fatalf("loading objects: %v", err)
|
|
}
|
|
|
|
// Attach both read and write kprobes
|
|
writeKprobe, err := link.Kprobe("vfs_write", objs.KprobeVfsWrite, nil)
|
|
if err != nil {
|
|
objs.Close()
|
|
log.Fatalf("attach kprobe vfs_write: %v", err)
|
|
}
|
|
defer writeKprobe.Close()
|
|
|
|
readKprobe, err := link.Kprobe("vfs_read", objs.KprobeVfsRead, nil)
|
|
if err != nil {
|
|
writeKprobe.Close()
|
|
objs.Close()
|
|
log.Fatalf("attach kprobe vfs_read: %v", err)
|
|
}
|
|
defer readKprobe.Close()
|
|
|
|
writeKRetprobe, err := link.Kretprobe("vfs_write", objs.KretprobeVfsWrite, nil)
|
|
if err != nil {
|
|
readKprobe.Close()
|
|
writeKprobe.Close()
|
|
objs.Close()
|
|
log.Fatalf("attach kretprobe vfs_write: %v", err)
|
|
}
|
|
defer writeKRetprobe.Close()
|
|
|
|
readKRetprobe, err := link.Kretprobe("vfs_read", objs.KretprobeVfsRead, nil)
|
|
if err != nil {
|
|
writeKRetprobe.Close()
|
|
readKprobe.Close()
|
|
writeKprobe.Close()
|
|
objs.Close()
|
|
log.Fatalf("attach kretprobe vfs_read: %v", err)
|
|
}
|
|
defer readKRetprobe.Close()
|
|
|
|
ctx, cancel := context.WithCancel(context.Background())
|
|
defer cancel()
|
|
|
|
// Start different goroutines to track MariaDB PIDs and monitor tracking state
|
|
if err := monitorMariadbPIDsAndUpdate(ctx, m.wg, &objs); err != nil {
|
|
log.Fatalf("failed to start monitor: %v", err)
|
|
}
|
|
|
|
monitorTrackingState(ctx, m.wg, &objs, m)
|
|
|
|
// Set up perf reader
|
|
// Each event is 320 bytes, so it can hold around 800 events in buffer
|
|
rd, err := perf.NewReader(objs.Events, os.Getpagesize()*64)
|
|
if err != nil {
|
|
log.Fatalf("creating perf reader: %v", err)
|
|
}
|
|
defer rd.Close()
|
|
|
|
m.wg.Add(1)
|
|
go func() {
|
|
defer m.wg.Done()
|
|
var e eBPFEvent
|
|
|
|
for {
|
|
record, err := rd.Read()
|
|
if err != nil {
|
|
if isPerfClosedError(err) {
|
|
return
|
|
}
|
|
log.Printf("read error: %v", err)
|
|
time.Sleep(50 * time.Millisecond)
|
|
continue
|
|
}
|
|
|
|
if record.LostSamples != 0 {
|
|
log.Printf("lost %d samples", record.LostSamples)
|
|
continue
|
|
}
|
|
|
|
// Check if channel is closd
|
|
if m.ebpfEventChan == nil {
|
|
return
|
|
}
|
|
|
|
if err := binary.Read(bytes.NewReader(record.RawSample), binary.LittleEndian, &e); err != nil {
|
|
log.Printf("parsing event: %v", err)
|
|
continue
|
|
}
|
|
|
|
m.ebpfEventChan <- e
|
|
}
|
|
}()
|
|
|
|
<-m.ctx.Done()
|
|
log.Println("shutting down monitor...")
|
|
|
|
if err := rd.Close(); err != nil {
|
|
log.Printf("error closing perf reader: %v", err)
|
|
}
|
|
|
|
// cleanup BPF side
|
|
if err := objs.Close(); err != nil {
|
|
log.Printf("warning closing bpf objects: %v", err)
|
|
}
|
|
|
|
fmt.Println("Shutdown complete.")
|
|
}
|
|
|
|
func monitorMariadbPIDsAndUpdate(ctx context.Context, wg *sync.WaitGroup, objs *io_traceObjects) error {
|
|
// sanity checks
|
|
if ctx == nil {
|
|
return fmt.Errorf("nil context")
|
|
}
|
|
if objs == nil || objs.Pids == nil {
|
|
return fmt.Errorf("objs or objs.Pids is nil")
|
|
}
|
|
|
|
const pollInterval = 5 * time.Second
|
|
const mapValue = uint8(1)
|
|
|
|
// helper to collect current mariadb PIDs
|
|
fetchMariadb := func() (map[uint32]struct{}, error) {
|
|
procs, err := process.Processes()
|
|
if err != nil {
|
|
return nil, fmt.Errorf("getting processes: %w", err)
|
|
}
|
|
out := make(map[uint32]struct{})
|
|
for _, p := range procs {
|
|
name, err := p.Name()
|
|
if err != nil {
|
|
continue
|
|
}
|
|
|
|
if strings.Compare(strings.ToLower(strings.TrimSpace(name)), "mariadbd") == 0 {
|
|
out[uint32(p.Pid)] = struct{}{}
|
|
}
|
|
}
|
|
return out, nil
|
|
}
|
|
|
|
// initial population of the eBPF map
|
|
initial, err := fetchMariadb()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
for pid := range initial {
|
|
if err := objs.Pids.Update(pid, mapValue, ebpf.UpdateAny); err != nil {
|
|
log.Printf("warning: failed to add initial pid %d: %v", pid, err)
|
|
} else {
|
|
log.Printf("tracking pid %d (initial)", pid)
|
|
}
|
|
}
|
|
|
|
// maintain an in-memory set of tracked PIDs for quick reference (optional but kept here per request)
|
|
tracked := make(map[uint32]struct{}, len(initial))
|
|
for pid := range initial {
|
|
tracked[pid] = struct{}{}
|
|
}
|
|
|
|
ticker := time.NewTicker(pollInterval)
|
|
wg.Add(1)
|
|
go func() {
|
|
defer wg.Done()
|
|
defer ticker.Stop()
|
|
|
|
for {
|
|
select {
|
|
case <-ctx.Done():
|
|
log.Printf("MonitorMariadbPIDs stopping: %v", ctx.Err())
|
|
return
|
|
case <-ticker.C:
|
|
newSet, err := fetchMariadb()
|
|
if err != nil {
|
|
log.Printf("error fetching mariadb pids: %v", err)
|
|
continue
|
|
}
|
|
|
|
// Read current keys from eBPF map to ensure we remove any keys that might not be in tracked (robustness)
|
|
existing := make(map[uint32]struct{})
|
|
var k uint32
|
|
var v uint8
|
|
it := objs.Pids.Iterate()
|
|
for it.Next(&k, &v) {
|
|
existing[k] = struct{}{}
|
|
}
|
|
if err := it.Err(); err != nil {
|
|
log.Printf("error iterating ebpf map: %v", err)
|
|
}
|
|
|
|
// Delete stale keys: those present in existing but not in newSet
|
|
for pid := range existing {
|
|
if _, keep := newSet[pid]; !keep {
|
|
if err := objs.Pids.Delete(pid); err != nil {
|
|
log.Printf("failed to delete stale pid %d: %v", pid, err)
|
|
} else {
|
|
delete(tracked, pid)
|
|
log.Printf("stopped tracking pid %d (deleted)", pid)
|
|
}
|
|
}
|
|
}
|
|
|
|
// Add new keys: those in newSet but not present in existing
|
|
for pid := range newSet {
|
|
if _, present := existing[pid]; present {
|
|
// already in map
|
|
if _, t := tracked[pid]; !t {
|
|
// ensure in-memory tracked stays consistent
|
|
tracked[pid] = struct{}{}
|
|
}
|
|
continue
|
|
}
|
|
if err := objs.Pids.Update(pid, mapValue, ebpf.UpdateAny); err != nil {
|
|
log.Printf("failed to add pid %d: %v", pid, err)
|
|
} else {
|
|
tracked[pid] = struct{}{}
|
|
log.Printf("tracking pid %d (added)", pid)
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}()
|
|
|
|
return nil
|
|
}
|
|
|
|
func monitorTrackingState(ctx context.Context, wg *sync.WaitGroup, objs *io_traceObjects, monitor *Monitor) {
|
|
const checkInterval = 1 * time.Second
|
|
|
|
ticker := time.NewTicker(checkInterval)
|
|
wg.Add(1)
|
|
go func() {
|
|
defer wg.Done()
|
|
defer ticker.Stop()
|
|
|
|
enabled := false
|
|
|
|
for {
|
|
select {
|
|
case <-ctx.Done():
|
|
log.Printf("monitorTrackingState stopping: %v", ctx.Err())
|
|
return
|
|
case <-ticker.C:
|
|
shouldEnable := monitor.enableEbpfEventCollection
|
|
if shouldEnable && !enabled {
|
|
// Enable processing
|
|
val := uint8(1)
|
|
if err := objs.GlobalProcessingFlag.Update(uint32(0), val, ebpf.UpdateAny); err != nil {
|
|
log.Printf("failed to enable global processing in eBPF: %v", err)
|
|
} else {
|
|
enabled = true
|
|
log.Println("eBPF global processing ENABLED")
|
|
}
|
|
} else if !shouldEnable && enabled {
|
|
// Disable processing
|
|
val := uint8(0)
|
|
if err := objs.GlobalProcessingFlag.Update(uint32(0), val, ebpf.UpdateAny); err != nil {
|
|
log.Printf("failed to disable global processing in eBPF: %v", err)
|
|
} else {
|
|
enabled = false
|
|
log.Println("eBPF global processing DISABLED")
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}()
|
|
}
|