ussher/pastebin.go
2020-11-13 20:52:44 -08:00

198 lines
4 KiB
Go

package ussher
import (
"bytes"
"crypto/sha256"
"encoding/base64"
"errors"
"fmt"
"io"
"io/ioutil"
"net/http"
"net/url"
"os"
"path"
"path/filepath"
"strings"
"sync"
"github.com/prometheus/common/log"
"github.com/spf13/viper"
bolt "go.etcd.io/bbolt"
"golang.org/x/crypto/ssh"
)
const (
	// DefaultBinLinkLen is the minimum length, in characters, of a pastebin
	// shortlink. A new paste starts with a prefix of this length and only
	// gets a longer one when the shorter prefix is already taken by
	// different content.
	DefaultBinLinkLen = 4
)
// pastebin is the pastebin app. It is registered both as an HTTP handler
// under "/p/" (ServeHTTP) and as the "pastebin" SSH exec app (HandleExec) by
// RegisterPastebin.
type pastebin struct {
// NOTE(review): no method in this file locks this embedded mutex — confirm
// whether it is still needed.
sync.Mutex
// db is the bolt index (datadir/index.db). Its "pastebin" bucket maps each
// shortlink ID to the SHA-256 digest of the paste's content.
db *bolt.DB
// datadir is the absolute directory holding index.db and the paste files,
// each named by the URL-safe base64 encoding of its SHA-256 digest.
datadir string
// config supplies BaseURL for generated shortlinks and the HTTP mux / SSH
// app table this app is registered on.
config *Config
}
// HandleExec dispatches an SSH exec request for the pastebin app. appstring
// is the raw command line, split on runs of whitespace. Its first word is not
// inspected here (presumably the app name); the second selects the action:
//
//	pastebin            upload the channel's input as a new paste
//	pastebin create     same as above
//	pastebin get <id>   write an existing paste (shortlink ID or URL) to the channel
//
// An unrecognised subcommand is reported as an error rather than ignored.
func (p *pastebin) HandleExec(appstring string, conn *ssh.ServerConn, channel ssh.Channel) error {
	// strings.Fields (unlike Split on " ") ignores repeated and trailing
	// spaces, so "pastebin " is treated like "pastebin".
	commands := strings.Fields(appstring)
	if len(commands) <= 1 {
		return p.handleNewFile(channel)
	}
	switch commands[1] {
	case "create":
		return p.handleNewFile(channel)
	case "get":
		return p.getFileSSH(commands[2:], channel)
	default:
		return fmt.Errorf("unknown pastebin command %q", commands[1])
	}
}
// getFileSSH writes a stored paste to channel. commands[0] is either a bare
// shortlink ID or an http(s) URL whose second path segment is the ID (as in
// ".../p/<id>"). Any further elements of commands are ignored.
func (p *pastebin) getFileSSH(commands []string, channel ssh.Channel) error {
	if len(commands) == 0 {
		return errors.New("Need to provide a shortlink string or URL of file to get")
	}
	key := commands[0]
	// A bare ID also parses as a URL (with an empty scheme), so only treat
	// the argument as a link when it actually carries an http/https scheme.
	if u, err := url.Parse(key); err == nil && strings.HasPrefix(u.Scheme, "http") {
		id, err := p.getIDfromURL(u)
		if err != nil {
			return err
		}
		key = id
	}
	f, err := p.getFileFromID(key)
	if err != nil {
		return err
	}
	// Release the descriptor once the paste has been streamed out.
	defer f.Close()
	_, err = io.Copy(channel, f)
	return err
}
// getFileFromID opens the stored paste whose shortlink ID is id.
//
// The "pastebin" bucket maps each ID to the SHA-256 digest of the content.
// The content itself lives in datadir under the URL-safe base64 encoding of
// that digest. It returns os.ErrNotExist when id is unknown, or any error from
// the database or os.Open. The caller must close the returned file.
func (p *pastebin) getFileFromID(id string) (*os.File, error) {
	var urlenc string
	// A lookup never modifies the index, so a read-only transaction is
	// sufficient. A missing bucket simply means nothing has been stored yet.
	err := p.db.View(func(tx *bolt.Tx) error {
		bucket := tx.Bucket([]byte("pastebin"))
		if bucket == nil {
			return os.ErrNotExist
		}
		v := bucket.Get([]byte(id))
		if v == nil {
			return os.ErrNotExist
		}
		// Values returned by bolt are only valid while the transaction is
		// open, so derive the filename before it ends.
		urlenc = base64.URLEncoding.EncodeToString(v)
		return nil
	})
	if err != nil {
		return nil, err
	}
	return os.Open(filepath.Join(p.datadir, urlenc))
}
// getIDfromURL extracts the shortlink ID from a paste URL. The ID is the
// second segment of the path, so "/p/abcd" yields "abcd". It returns an error
// when the path has no such segment or the segment is empty. u must not be
// nil.
func (p *pastebin) getIDfromURL(u *url.URL) (string, error) {
	if u == nil {
		panic("nil URL")
	}
	// "/p/abcd" splits into ["", "p", "abcd"]: the ID sits at index 2, so at
	// least three components are required (the old "< 2" check let "/p"
	// through and then panicked on the index).
	components := strings.Split(u.Path, "/")
	if len(components) < 3 || components[2] == "" {
		return "", errors.New("Incomplete URL")
	}
	return components[2], nil
}
// handleNewFile stores everything read from channel (until EOF) as a new
// paste and writes its shortlink URL back to channel.
//
// The content is saved in datadir under the URL-safe base64 encoding of its
// SHA-256 digest, so identical uploads share one file. The shortlink ID is the
// shortest prefix of that name, of at least DefaultBinLinkLen characters, that
// is either free in the index or already points at the same digest.
func (p *pastebin) handleNewFile(channel ssh.Channel) error {
	shasum, err := p.storeUpload(channel)
	if err != nil {
		return err
	}
	urlenc := base64.URLEncoding.EncodeToString(shasum)
	shortID, err := p.reserveShortID(urlenc, shasum)
	if err != nil {
		return err
	}
	// Append the path to BaseURL textually. path.Join on the whole URL would
	// Clean it and collapse "https://host" into "https:/host". shortID is
	// base64url, so it contains no '/' or '.' that Join could rewrite.
	shortlink := strings.TrimRight(p.config.BaseURL, "/") + path.Join("/p", shortID)
	_, err = fmt.Fprintln(channel, shortlink)
	return err
}

// storeUpload copies r into a temporary file in p.datadir while hashing it,
// then renames the file to the URL-safe base64 encoding of its SHA-256 digest
// and returns the digest. The temporary file is removed if any step fails.
func (p *pastebin) storeUpload(r io.Reader) (shasum []byte, err error) {
	f, err := ioutil.TempFile(p.datadir, "pastebin_*")
	if err != nil {
		return nil, err
	}
	tempname := f.Name()
	defer func() {
		if err != nil {
			// Best-effort cleanup; the error that got us here is the one
			// worth reporting.
			os.Remove(tempname)
		}
	}()

	hasher := sha256.New()
	if _, err = io.Copy(f, io.TeeReader(r, hasher)); err != nil {
		f.Close() // the copy error is more informative than any close error
		return nil, err
	}
	// A failed Close can mean buffered data never reached disk, so it must
	// not be ignored before the file is published under its final name.
	if err = f.Close(); err != nil {
		return nil, err
	}

	shasum = hasher.Sum(nil)
	urlenc := base64.URLEncoding.EncodeToString(shasum)
	if err = os.Rename(tempname, filepath.Join(p.datadir, urlenc)); err != nil {
		return nil, err
	}
	return shasum, nil
}

// reserveShortID records in the "pastebin" bucket the shortest prefix of
// urlenc (at least DefaultBinLinkLen characters) that is unused or already
// maps to shasum, and returns that prefix.
func (p *pastebin) reserveShortID(urlenc string, shasum []byte) (string, error) {
	n := DefaultBinLinkLen
	err := p.db.Update(func(tx *bolt.Tx) error {
		bucket, err := tx.CreateBucketIfNotExists([]byte("pastebin"))
		if err != nil {
			return err
		}
		// Lengthen the prefix until it is free, or until it already points at
		// this same content (re-uploading identical data reuses its link).
		for n < len(urlenc) {
			v := bucket.Get([]byte(urlenc[:n]))
			if v == nil || bytes.Equal(v, shasum) {
				break
			}
			n++
		}
		return bucket.Put([]byte(urlenc[:n]), shasum)
	})
	if err != nil {
		return "", err
	}
	return urlenc[:n], nil
}
// Close shuts the pastebin down by closing its bolt index database and
// returning whatever error that reports.
func (p *pastebin) Close() (err error) {
	if err = p.db.Close(); err != nil {
		return err
	}
	return nil
}
// ServeHTTP serves a stored paste at /p/<id>.
//
// GET and HEAD are supported; other methods get 405. A malformed path gets
// 400, an unknown ID 404, and storage failures 500 (logged). Paste content is
// an arbitrary user upload, so its type is sniffed here. Every text type other
// than text/plain is downgraded to plain text, and browsers are told not to
// re-sniff, so a paste cannot be rendered as HTML/script on this origin.
func (p *pastebin) ServeHTTP(w http.ResponseWriter, req *http.Request) {
	if req.Method != http.MethodGet && req.Method != http.MethodHead {
		w.Header().Set("Allow", "GET, HEAD")
		w.WriteHeader(http.StatusMethodNotAllowed)
		return
	}
	internalError := func(err error) {
		log.Errorln(err)
		w.WriteHeader(http.StatusInternalServerError)
	}

	id, err := p.getIDfromURL(req.URL)
	if err != nil {
		w.WriteHeader(http.StatusBadRequest)
		return
	}
	f, err := p.getFileFromID(id)
	if err != nil {
		if os.IsNotExist(err) {
			w.WriteHeader(http.StatusNotFound)
			return
		}
		internalError(err)
		return
	}
	defer f.Close()

	fi, err := f.Stat()
	if err != nil {
		internalError(err)
		return
	}

	// Sniff from the first 512 bytes (all DetectContentType looks at), then
	// rewind so the full paste is served.
	var head [512]byte
	n, err := io.ReadFull(f, head[:])
	if err != nil && err != io.EOF && err != io.ErrUnexpectedEOF {
		internalError(err)
		return
	}
	if _, err := f.Seek(0, io.SeekStart); err != nil {
		internalError(err)
		return
	}
	ctype := http.DetectContentType(head[:n])
	if strings.HasPrefix(ctype, "text/") && !strings.HasPrefix(ctype, "text/plain") {
		ctype = "text/plain; charset=utf-8"
	}
	w.Header().Set("Content-Type", ctype)
	w.Header().Set("X-Content-Type-Options", "nosniff")

	// ServeContent handles HEAD, Range and conditional requests. It keeps the
	// Content-Type set above because one is already present.
	http.ServeContent(w, req, "", fi.ModTime(), f)
}
// RegisterPastebin creates the pastebin app and registers it on config, both
// as the HTTP handler for "/p/" and as the "pastebin" SSH app.
//
// configNS's "datadir" key names the directory holding the bolt index
// (index.db) and the paste files. A relative path is resolved against the
// working directory, and the directory is created if it does not exist.
func RegisterPastebin(config *Config, configNS *viper.Viper) error {
	// Named datadir (not path) so the imported path package is not shadowed.
	datadir, err := filepath.Abs(configNS.GetString("datadir"))
	if err != nil {
		return err
	}
	// Opening index.db creates the file but not its directory, so make sure
	// the data directory exists first.
	if err := os.MkdirAll(datadir, 0700); err != nil {
		return err
	}
	db, err := bolt.Open(filepath.Join(datadir, "index.db"), 0600, nil)
	if err != nil {
		return err
	}
	p := &pastebin{
		db:      db,
		datadir: datadir,
		config:  config,
	}
	config.HTTPMux.Handle("/p/", p)
	config.SSHApps["pastebin"] = p
	return nil
}