Compare commits
4 Commits
Author | SHA1 | Date | |
---|---|---|---|
fd33b43f31
|
|||
be4de07eb8
|
|||
36ed474bfe
|
|||
fdc590c8c3
|
@@ -1,8 +1,8 @@
|
|||||||
package cmdext
|
package cmdext
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"bufio"
|
"gogs.mikescher.com/BlackForestBytes/goext/mathext"
|
||||||
"io"
|
"gogs.mikescher.com/BlackForestBytes/goext/syncext"
|
||||||
"os/exec"
|
"os/exec"
|
||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
@@ -30,119 +30,41 @@ func run(opt CommandRunner) (CommandResult, error) {
|
|||||||
return CommandResult{}, err
|
return CommandResult{}, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
preader := pipeReader{
|
||||||
|
stdout: stdoutPipe,
|
||||||
|
stderr: stderrPipe,
|
||||||
|
}
|
||||||
|
|
||||||
err = cmd.Start()
|
err = cmd.Start()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return CommandResult{}, err
|
return CommandResult{}, err
|
||||||
}
|
}
|
||||||
|
|
||||||
errch := make(chan error, 3)
|
type resultObj struct {
|
||||||
go func() { errch <- cmd.Wait() }()
|
stdout string
|
||||||
|
stderr string
|
||||||
|
stdcombined string
|
||||||
|
err error
|
||||||
|
}
|
||||||
|
|
||||||
// [1] read raw stdout
|
outputChan := make(chan resultObj)
|
||||||
|
|
||||||
stdoutBufferReader, stdoutBufferWriter := io.Pipe()
|
|
||||||
stdout := ""
|
|
||||||
go func() {
|
go func() {
|
||||||
buf := make([]byte, 128)
|
// we need to first fully read the pipes and then call Wait
|
||||||
for true {
|
// see https://pkg.go.dev/os/exec#Cmd.StdoutPipe
|
||||||
n, out := stdoutPipe.Read(buf)
|
|
||||||
|
|
||||||
if n > 0 {
|
stdout, stderr, stdcombined, err := preader.Read(opt.listener)
|
||||||
txt := string(buf[:n])
|
if err != nil {
|
||||||
stdout += txt
|
outputChan <- resultObj{stdout, stderr, stdcombined, err}
|
||||||
_, _ = stdoutBufferWriter.Write(buf[:n])
|
|
||||||
for _, lstr := range opt.listener {
|
|
||||||
lstr.ReadRawStdout(buf[:n])
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if out == io.EOF {
|
|
||||||
break
|
|
||||||
}
|
|
||||||
if out != nil {
|
|
||||||
errch <- out
|
|
||||||
_ = cmd.Process.Kill()
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
_ = stdoutBufferWriter.Close()
|
|
||||||
}()
|
|
||||||
|
|
||||||
// [2] read raw stderr
|
err = cmd.Wait()
|
||||||
|
if err != nil {
|
||||||
stderrBufferReader, stderrBufferWriter := io.Pipe()
|
outputChan <- resultObj{stdout, stderr, stdcombined, err}
|
||||||
stderr := ""
|
|
||||||
go func() {
|
|
||||||
buf := make([]byte, 128)
|
|
||||||
for true {
|
|
||||||
n, err := stderrPipe.Read(buf)
|
|
||||||
|
|
||||||
if n > 0 {
|
|
||||||
txt := string(buf[:n])
|
|
||||||
stderr += txt
|
|
||||||
_, _ = stderrBufferWriter.Write(buf[:n])
|
|
||||||
for _, lstr := range opt.listener {
|
|
||||||
lstr.ReadRawStderr(buf[:n])
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if err == io.EOF {
|
|
||||||
break
|
|
||||||
}
|
|
||||||
if err != nil {
|
|
||||||
errch <- err
|
|
||||||
_ = cmd.Process.Kill()
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
_ = stderrBufferWriter.Close()
|
|
||||||
|
outputChan <- resultObj{stdout, stderr, stdcombined, nil}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
combch := make(chan string, 32)
|
|
||||||
stopCombch := make(chan bool)
|
|
||||||
|
|
||||||
// [3] collect stdout line-by-line
|
|
||||||
|
|
||||||
go func() {
|
|
||||||
scanner := bufio.NewScanner(stdoutBufferReader)
|
|
||||||
for scanner.Scan() {
|
|
||||||
txt := scanner.Text()
|
|
||||||
for _, lstr := range opt.listener {
|
|
||||||
lstr.ReadStdoutLine(txt)
|
|
||||||
}
|
|
||||||
combch <- txt
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
|
|
||||||
// [4] collect stderr line-by-line
|
|
||||||
|
|
||||||
go func() {
|
|
||||||
scanner := bufio.NewScanner(stderrBufferReader)
|
|
||||||
for scanner.Scan() {
|
|
||||||
txt := scanner.Text()
|
|
||||||
for _, lstr := range opt.listener {
|
|
||||||
lstr.ReadStderrLine(txt)
|
|
||||||
}
|
|
||||||
combch <- txt
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
|
|
||||||
defer func() { stopCombch <- true }()
|
|
||||||
|
|
||||||
// [5] combine stdcombined
|
|
||||||
|
|
||||||
stdcombined := ""
|
|
||||||
go func() {
|
|
||||||
for {
|
|
||||||
select {
|
|
||||||
case txt := <-combch:
|
|
||||||
stdcombined += txt + "\n" // this comes from bufio.Scanner and has no newlines...
|
|
||||||
case <-stopCombch:
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
|
|
||||||
// [6] run
|
|
||||||
|
|
||||||
var timeoutChan <-chan time.Time = make(chan time.Time, 1)
|
var timeoutChan <-chan time.Time = make(chan time.Time, 1)
|
||||||
if opt.timeout != nil {
|
if opt.timeout != nil {
|
||||||
timeoutChan = time.After(*opt.timeout)
|
timeoutChan = time.After(*opt.timeout)
|
||||||
@@ -155,24 +77,37 @@ func run(opt CommandRunner) (CommandResult, error) {
|
|||||||
for _, lstr := range opt.listener {
|
for _, lstr := range opt.listener {
|
||||||
lstr.Timeout()
|
lstr.Timeout()
|
||||||
}
|
}
|
||||||
return CommandResult{
|
|
||||||
StdOut: stdout,
|
|
||||||
StdErr: stderr,
|
|
||||||
StdCombined: stdcombined,
|
|
||||||
ExitCode: -1,
|
|
||||||
CommandTimedOut: true,
|
|
||||||
}, nil
|
|
||||||
|
|
||||||
case err := <-errch:
|
if fallback, ok := syncext.ReadChannelWithTimeout(outputChan, mathext.Min(32*time.Millisecond, *opt.timeout)); ok {
|
||||||
if exiterr, ok := err.(*exec.ExitError); ok {
|
// most of the time the cmd.Process.Kill() should also ahve finished the pipereader
|
||||||
|
// and we can at least return the already collected stdout, stderr, etc
|
||||||
|
return CommandResult{
|
||||||
|
StdOut: fallback.stdout,
|
||||||
|
StdErr: fallback.stderr,
|
||||||
|
StdCombined: fallback.stdcombined,
|
||||||
|
ExitCode: -1,
|
||||||
|
CommandTimedOut: true,
|
||||||
|
}, nil
|
||||||
|
} else {
|
||||||
|
return CommandResult{
|
||||||
|
StdOut: "",
|
||||||
|
StdErr: "",
|
||||||
|
StdCombined: "",
|
||||||
|
ExitCode: -1,
|
||||||
|
CommandTimedOut: true,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
case outobj := <-outputChan:
|
||||||
|
if exiterr, ok := outobj.err.(*exec.ExitError); ok {
|
||||||
excode := exiterr.ExitCode()
|
excode := exiterr.ExitCode()
|
||||||
for _, lstr := range opt.listener {
|
for _, lstr := range opt.listener {
|
||||||
lstr.Finished(excode)
|
lstr.Finished(excode)
|
||||||
}
|
}
|
||||||
return CommandResult{
|
return CommandResult{
|
||||||
StdOut: stdout,
|
StdOut: outobj.stdout,
|
||||||
StdErr: stderr,
|
StdErr: outobj.stderr,
|
||||||
StdCombined: stdcombined,
|
StdCombined: outobj.stdcombined,
|
||||||
ExitCode: excode,
|
ExitCode: excode,
|
||||||
CommandTimedOut: false,
|
CommandTimedOut: false,
|
||||||
}, nil
|
}, nil
|
||||||
@@ -183,9 +118,9 @@ func run(opt CommandRunner) (CommandResult, error) {
|
|||||||
lstr.Finished(0)
|
lstr.Finished(0)
|
||||||
}
|
}
|
||||||
return CommandResult{
|
return CommandResult{
|
||||||
StdOut: stdout,
|
StdOut: outobj.stdout,
|
||||||
StdErr: stderr,
|
StdErr: outobj.stderr,
|
||||||
StdCombined: stdcombined,
|
StdCombined: outobj.stdcombined,
|
||||||
ExitCode: 0,
|
ExitCode: 0,
|
||||||
CommandTimedOut: false,
|
CommandTimedOut: false,
|
||||||
}, nil
|
}, nil
|
||||||
|
@@ -1,6 +1,10 @@
|
|||||||
package cmdext
|
package cmdext
|
||||||
|
|
||||||
import "testing"
|
import (
|
||||||
|
"fmt"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
func TestStdout(t *testing.T) {
|
func TestStdout(t *testing.T) {
|
||||||
|
|
||||||
@@ -57,3 +61,166 @@ func TestStdcombined(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestPartialRead(t *testing.T) {
|
||||||
|
res1, err := Runner("python").
|
||||||
|
Arg("-c").
|
||||||
|
Arg("import sys; import time; print(\"first message\", flush=True); time.sleep(5); print(\"cant see me\", flush=True);").
|
||||||
|
Timeout(100 * time.Millisecond).
|
||||||
|
Run()
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("%v", err)
|
||||||
|
}
|
||||||
|
if !res1.CommandTimedOut {
|
||||||
|
t.Errorf("!CommandTimedOut")
|
||||||
|
}
|
||||||
|
if res1.StdErr != "" {
|
||||||
|
t.Errorf("res1.StdErr == '%v'", res1.StdErr)
|
||||||
|
}
|
||||||
|
if res1.StdOut != "first message\n" {
|
||||||
|
t.Errorf("res1.StdOut == '%v'", res1.StdOut)
|
||||||
|
}
|
||||||
|
if res1.StdCombined != "first message\n" {
|
||||||
|
t.Errorf("res1.StdCombined == '%v'", res1.StdCombined)
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPartialReadStderr(t *testing.T) {
|
||||||
|
res1, err := Runner("python").
|
||||||
|
Arg("-c").
|
||||||
|
Arg("import sys; import time; print(\"first message\", file=sys.stderr, flush=True); time.sleep(5); print(\"cant see me\", file=sys.stderr, flush=True);").
|
||||||
|
Timeout(100 * time.Millisecond).
|
||||||
|
Run()
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("%v", err)
|
||||||
|
}
|
||||||
|
if !res1.CommandTimedOut {
|
||||||
|
t.Errorf("!CommandTimedOut")
|
||||||
|
}
|
||||||
|
if res1.StdErr != "first message\n" {
|
||||||
|
t.Errorf("res1.StdErr == '%v'", res1.StdErr)
|
||||||
|
}
|
||||||
|
if res1.StdOut != "" {
|
||||||
|
t.Errorf("res1.StdOut == '%v'", res1.StdOut)
|
||||||
|
}
|
||||||
|
if res1.StdCombined != "first message\n" {
|
||||||
|
t.Errorf("res1.StdCombined == '%v'", res1.StdCombined)
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestReadUnflushedStdout(t *testing.T) {
|
||||||
|
|
||||||
|
res1, err := Runner("python").Arg("-c").Arg("import sys; print(\"message101\", file=sys.stdout, end='')").Run()
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("%v", err)
|
||||||
|
}
|
||||||
|
if res1.StdErr != "" {
|
||||||
|
t.Errorf("res1.StdErr == '%v'", res1.StdErr)
|
||||||
|
}
|
||||||
|
if res1.StdOut != "message101" {
|
||||||
|
t.Errorf("res1.StdOut == '%v'", res1.StdOut)
|
||||||
|
}
|
||||||
|
if res1.StdCombined != "message101\n" {
|
||||||
|
t.Errorf("res1.StdCombined == '%v'", res1.StdCombined)
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestReadUnflushedStderr(t *testing.T) {
|
||||||
|
|
||||||
|
res1, err := Runner("python").Arg("-c").Arg("import sys; print(\"message101\", file=sys.stderr, end='')").Run()
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("%v", err)
|
||||||
|
}
|
||||||
|
if res1.StdErr != "message101" {
|
||||||
|
t.Errorf("res1.StdErr == '%v'", res1.StdErr)
|
||||||
|
}
|
||||||
|
if res1.StdOut != "" {
|
||||||
|
t.Errorf("res1.StdOut == '%v'", res1.StdOut)
|
||||||
|
}
|
||||||
|
if res1.StdCombined != "message101\n" {
|
||||||
|
t.Errorf("res1.StdCombined == '%v'", res1.StdCombined)
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPartialReadUnflushed(t *testing.T) {
|
||||||
|
t.SkipNow()
|
||||||
|
|
||||||
|
res1, err := Runner("python").
|
||||||
|
Arg("-c").
|
||||||
|
Arg("import sys; import time; print(\"first message\", end=''); time.sleep(5); print(\"cant see me\", end='');").
|
||||||
|
Timeout(100 * time.Millisecond).
|
||||||
|
Run()
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("%v", err)
|
||||||
|
}
|
||||||
|
if !res1.CommandTimedOut {
|
||||||
|
t.Errorf("!CommandTimedOut")
|
||||||
|
}
|
||||||
|
if res1.StdErr != "" {
|
||||||
|
t.Errorf("res1.StdErr == '%v'", res1.StdErr)
|
||||||
|
}
|
||||||
|
if res1.StdOut != "first message" {
|
||||||
|
t.Errorf("res1.StdOut == '%v'", res1.StdOut)
|
||||||
|
}
|
||||||
|
if res1.StdCombined != "first message" {
|
||||||
|
t.Errorf("res1.StdCombined == '%v'", res1.StdCombined)
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPartialReadUnflushedStderr(t *testing.T) {
|
||||||
|
t.SkipNow()
|
||||||
|
|
||||||
|
res1, err := Runner("python").
|
||||||
|
Arg("-c").
|
||||||
|
Arg("import sys; import time; print(\"first message\", file=sys.stderr, end=''); time.sleep(5); print(\"cant see me\", file=sys.stderr, end='');").
|
||||||
|
Timeout(100 * time.Millisecond).
|
||||||
|
Run()
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("%v", err)
|
||||||
|
}
|
||||||
|
if !res1.CommandTimedOut {
|
||||||
|
t.Errorf("!CommandTimedOut")
|
||||||
|
}
|
||||||
|
if res1.StdErr != "first message" {
|
||||||
|
t.Errorf("res1.StdErr == '%v'", res1.StdErr)
|
||||||
|
}
|
||||||
|
if res1.StdOut != "" {
|
||||||
|
t.Errorf("res1.StdOut == '%v'", res1.StdOut)
|
||||||
|
}
|
||||||
|
if res1.StdCombined != "first message" {
|
||||||
|
t.Errorf("res1.StdCombined == '%v'", res1.StdCombined)
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestListener(t *testing.T) {
|
||||||
|
|
||||||
|
_, err := Runner("python").
|
||||||
|
Arg("-c").
|
||||||
|
Arg("import sys;" +
|
||||||
|
"import time;" +
|
||||||
|
"print(\"message 1\", flush=True);" +
|
||||||
|
"time.sleep(1);" +
|
||||||
|
"print(\"message 2\", flush=True);" +
|
||||||
|
"time.sleep(1);" +
|
||||||
|
"print(\"message 3\", flush=True);" +
|
||||||
|
"time.sleep(1);" +
|
||||||
|
"print(\"message 4\", file=sys.stderr, flush=True);" +
|
||||||
|
"time.sleep(1);" +
|
||||||
|
"print(\"message 5\", flush=True);" +
|
||||||
|
"time.sleep(1);" +
|
||||||
|
"print(\"final\");").
|
||||||
|
ListenStdout(func(s string) { fmt.Printf("@@STDOUT <<- %v (%v)\n", s, time.Now().Format(time.RFC3339Nano)) }).
|
||||||
|
ListenStderr(func(s string) { fmt.Printf("@@STDERR <<- %v (%v)\n", s, time.Now().Format(time.RFC3339Nano)) }).
|
||||||
|
Timeout(10 * time.Second).
|
||||||
|
Run()
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
146
cmdext/pipereader.go
Normal file
146
cmdext/pipereader.go
Normal file
@@ -0,0 +1,146 @@
|
|||||||
|
package cmdext
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bufio"
|
||||||
|
"gogs.mikescher.com/BlackForestBytes/goext/syncext"
|
||||||
|
"io"
|
||||||
|
"sync"
|
||||||
|
)
|
||||||
|
|
||||||
|
type pipeReader struct {
|
||||||
|
stdout io.ReadCloser
|
||||||
|
stderr io.ReadCloser
|
||||||
|
}
|
||||||
|
|
||||||
|
// Read ready stdout and stdin until finished
|
||||||
|
// also splits both pipes into lines and calld the listener
|
||||||
|
func (pr *pipeReader) Read(listener []CommandListener) (string, string, string, error) {
|
||||||
|
type combevt struct {
|
||||||
|
line string
|
||||||
|
stop bool
|
||||||
|
}
|
||||||
|
|
||||||
|
errch := make(chan error, 8)
|
||||||
|
|
||||||
|
wg := sync.WaitGroup{}
|
||||||
|
|
||||||
|
// [1] read raw stdout
|
||||||
|
|
||||||
|
wg.Add(1)
|
||||||
|
stdoutBufferReader, stdoutBufferWriter := io.Pipe()
|
||||||
|
stdout := ""
|
||||||
|
go func() {
|
||||||
|
buf := make([]byte, 128)
|
||||||
|
for true {
|
||||||
|
n, out := pr.stdout.Read(buf)
|
||||||
|
|
||||||
|
if n > 0 {
|
||||||
|
txt := string(buf[:n])
|
||||||
|
stdout += txt
|
||||||
|
_, _ = stdoutBufferWriter.Write(buf[:n])
|
||||||
|
for _, lstr := range listener {
|
||||||
|
lstr.ReadRawStdout(buf[:n])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if out == io.EOF {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
if out != nil {
|
||||||
|
errch <- out
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
_ = stdoutBufferWriter.Close()
|
||||||
|
wg.Done()
|
||||||
|
}()
|
||||||
|
|
||||||
|
// [2] read raw stderr
|
||||||
|
|
||||||
|
wg.Add(1)
|
||||||
|
stderrBufferReader, stderrBufferWriter := io.Pipe()
|
||||||
|
stderr := ""
|
||||||
|
go func() {
|
||||||
|
buf := make([]byte, 128)
|
||||||
|
for true {
|
||||||
|
n, err := pr.stderr.Read(buf)
|
||||||
|
|
||||||
|
if n > 0 {
|
||||||
|
txt := string(buf[:n])
|
||||||
|
stderr += txt
|
||||||
|
_, _ = stderrBufferWriter.Write(buf[:n])
|
||||||
|
for _, lstr := range listener {
|
||||||
|
lstr.ReadRawStderr(buf[:n])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if err == io.EOF {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
errch <- err
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
_ = stderrBufferWriter.Close()
|
||||||
|
wg.Done()
|
||||||
|
}()
|
||||||
|
|
||||||
|
combch := make(chan combevt, 32)
|
||||||
|
|
||||||
|
// [3] collect stdout line-by-line
|
||||||
|
|
||||||
|
wg.Add(1)
|
||||||
|
go func() {
|
||||||
|
scanner := bufio.NewScanner(stdoutBufferReader)
|
||||||
|
for scanner.Scan() {
|
||||||
|
txt := scanner.Text()
|
||||||
|
for _, lstr := range listener {
|
||||||
|
lstr.ReadStdoutLine(txt)
|
||||||
|
}
|
||||||
|
combch <- combevt{txt, false}
|
||||||
|
}
|
||||||
|
combch <- combevt{"", true}
|
||||||
|
wg.Done()
|
||||||
|
}()
|
||||||
|
|
||||||
|
// [4] collect stderr line-by-line
|
||||||
|
|
||||||
|
wg.Add(1)
|
||||||
|
go func() {
|
||||||
|
scanner := bufio.NewScanner(stderrBufferReader)
|
||||||
|
for scanner.Scan() {
|
||||||
|
txt := scanner.Text()
|
||||||
|
for _, lstr := range listener {
|
||||||
|
lstr.ReadStderrLine(txt)
|
||||||
|
}
|
||||||
|
combch <- combevt{txt, false}
|
||||||
|
}
|
||||||
|
combch <- combevt{"", true}
|
||||||
|
wg.Done()
|
||||||
|
}()
|
||||||
|
|
||||||
|
// [5] combine stdcombined
|
||||||
|
|
||||||
|
wg.Add(1)
|
||||||
|
stdcombined := ""
|
||||||
|
go func() {
|
||||||
|
stopctr := 0
|
||||||
|
for stopctr < 2 {
|
||||||
|
vvv := <-combch
|
||||||
|
if vvv.stop {
|
||||||
|
stopctr++
|
||||||
|
} else {
|
||||||
|
stdcombined += vvv.line + "\n" // this comes from bufio.Scanner and has no newlines...
|
||||||
|
}
|
||||||
|
}
|
||||||
|
wg.Done()
|
||||||
|
}()
|
||||||
|
|
||||||
|
// wait for all (5) goroutines to finish
|
||||||
|
wg.Wait()
|
||||||
|
|
||||||
|
if err, ok := syncext.ReadNonBlocking(errch); ok {
|
||||||
|
return "", "", "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
return stdout, stderr, stdcombined, nil
|
||||||
|
}
|
136
cryptext/aes.go
136
cryptext/aes.go
@@ -1,10 +1,13 @@
|
|||||||
package cryptext
|
package cryptext
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"bytes"
|
||||||
"crypto/aes"
|
"crypto/aes"
|
||||||
"crypto/cipher"
|
"crypto/cipher"
|
||||||
"crypto/rand"
|
"crypto/rand"
|
||||||
"encoding/base64"
|
"crypto/sha256"
|
||||||
|
"encoding/base32"
|
||||||
|
"encoding/json"
|
||||||
"errors"
|
"errors"
|
||||||
"golang.org/x/crypto/scrypt"
|
"golang.org/x/crypto/scrypt"
|
||||||
"io"
|
"io"
|
||||||
@@ -12,35 +15,90 @@ import (
|
|||||||
|
|
||||||
// https://stackoverflow.com/a/18819040/1761622
|
// https://stackoverflow.com/a/18819040/1761622
|
||||||
|
|
||||||
func EncryptAESSimple(password, text []byte) ([]byte, error) {
|
type aesPayload struct {
|
||||||
|
Salt []byte `json:"s"`
|
||||||
key, err := scrypt.Key(password, nil, 32768, 8, 1, 32) // this is not 100% correct, rounds too low and salt is missing
|
IV []byte `json:"i"`
|
||||||
if err != nil {
|
Data []byte `json:"d"`
|
||||||
return nil, err
|
Rounds int `json:"r"`
|
||||||
}
|
Version uint `json:"v"`
|
||||||
|
|
||||||
block, err := aes.NewCipher(key)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
b := base64.StdEncoding.EncodeToString(text)
|
|
||||||
ciphertext := make([]byte, aes.BlockSize+len(b))
|
|
||||||
|
|
||||||
iv := ciphertext[:aes.BlockSize]
|
|
||||||
if _, err := io.ReadFull(rand.Reader, iv); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
cfb := cipher.NewCFBEncrypter(block, iv)
|
|
||||||
cfb.XORKeyStream(ciphertext[aes.BlockSize:], []byte(b))
|
|
||||||
|
|
||||||
return ciphertext, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func DecryptAESSimple(password, text []byte) ([]byte, error) {
|
func EncryptAESSimple(password []byte, data []byte, rounds int) (string, error) {
|
||||||
|
|
||||||
key, err := scrypt.Key(password, nil, 32768, 8, 1, 32) // this is not 100% correct, rounds too low and salt is missing
|
salt := make([]byte, 8)
|
||||||
|
_, err := io.ReadFull(rand.Reader, salt)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
key, err := scrypt.Key(password, salt, rounds, 8, 1, 32)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
block, err := aes.NewCipher(key)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
h := sha256.New()
|
||||||
|
h.Write(data)
|
||||||
|
checksum := h.Sum(nil)
|
||||||
|
if len(checksum) != 32 {
|
||||||
|
return "", errors.New("wrong cs size")
|
||||||
|
}
|
||||||
|
|
||||||
|
ciphertext := make([]byte, 32+len(data))
|
||||||
|
|
||||||
|
iv := make([]byte, aes.BlockSize)
|
||||||
|
_, err = io.ReadFull(rand.Reader, iv)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
combinedData := make([]byte, 0, 32+len(data))
|
||||||
|
combinedData = append(combinedData, checksum...)
|
||||||
|
combinedData = append(combinedData, data...)
|
||||||
|
|
||||||
|
cfb := cipher.NewCFBEncrypter(block, iv)
|
||||||
|
cfb.XORKeyStream(ciphertext, combinedData)
|
||||||
|
|
||||||
|
pl := aesPayload{
|
||||||
|
Salt: salt,
|
||||||
|
IV: iv,
|
||||||
|
Data: ciphertext,
|
||||||
|
Version: 1,
|
||||||
|
Rounds: rounds,
|
||||||
|
}
|
||||||
|
|
||||||
|
jbin, err := json.Marshal(pl)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
res := base32.StdEncoding.WithPadding(base32.NoPadding).EncodeToString(jbin)
|
||||||
|
|
||||||
|
return res, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func DecryptAESSimple(password []byte, encText string) ([]byte, error) {
|
||||||
|
|
||||||
|
jbin, err := base32.StdEncoding.WithPadding(base32.NoPadding).DecodeString(encText)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
var pl aesPayload
|
||||||
|
err = json.Unmarshal(jbin, &pl)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if pl.Version != 1 {
|
||||||
|
return nil, errors.New("unsupported version")
|
||||||
|
}
|
||||||
|
|
||||||
|
key, err := scrypt.Key(password, pl.Salt, pl.Rounds, 8, 1, 32) // this is not 100% correct, rounds too low and salt is missing
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@@ -50,18 +108,24 @@ func DecryptAESSimple(password, text []byte) ([]byte, error) {
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(text) < aes.BlockSize {
|
dest := make([]byte, len(pl.Data))
|
||||||
return nil, errors.New("ciphertext too short")
|
|
||||||
|
cfb := cipher.NewCFBDecrypter(block, pl.IV)
|
||||||
|
cfb.XORKeyStream(dest, pl.Data)
|
||||||
|
|
||||||
|
if len(dest) < 32 {
|
||||||
|
return nil, errors.New("payload too small")
|
||||||
}
|
}
|
||||||
|
|
||||||
iv := text[:aes.BlockSize]
|
chck := dest[:32]
|
||||||
text = text[aes.BlockSize:]
|
data := dest[32:]
|
||||||
cfb := cipher.NewCFBDecrypter(block, iv)
|
|
||||||
cfb.XORKeyStream(text, text)
|
|
||||||
|
|
||||||
data, err := base64.StdEncoding.DecodeString(string(text))
|
h := sha256.New()
|
||||||
if err != nil {
|
h.Write(data)
|
||||||
return nil, err
|
chck2 := h.Sum(nil)
|
||||||
|
|
||||||
|
if !bytes.Equal(chck, chck2) {
|
||||||
|
return nil, errors.New("checksum mismatch")
|
||||||
}
|
}
|
||||||
|
|
||||||
return data, nil
|
return data, nil
|
||||||
|
@@ -1,6 +1,9 @@
|
|||||||
package cryptext
|
package cryptext
|
||||||
|
|
||||||
import "testing"
|
import (
|
||||||
|
"fmt"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
func TestEncryptAESSimple(t *testing.T) {
|
func TestEncryptAESSimple(t *testing.T) {
|
||||||
|
|
||||||
@@ -8,15 +11,25 @@ func TestEncryptAESSimple(t *testing.T) {
|
|||||||
|
|
||||||
str1 := []byte("Hello World")
|
str1 := []byte("Hello World")
|
||||||
|
|
||||||
str2, err := EncryptAESSimple(pw, str1)
|
str2, err := EncryptAESSimple(pw, str1, 512)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
panic(err)
|
panic(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fmt.Printf("%s\n", str2)
|
||||||
|
|
||||||
str3, err := DecryptAESSimple(pw, str2)
|
str3, err := DecryptAESSimple(pw, str2)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
panic(err)
|
panic(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
assertEqual(t, string(str1), string(str3))
|
assertEqual(t, string(str1), string(str3))
|
||||||
|
|
||||||
|
str4, err := EncryptAESSimple(pw, str3, 512)
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
assertNotEqual(t, string(str2), string(str4))
|
||||||
|
|
||||||
}
|
}
|
||||||
|
@@ -23,3 +23,9 @@ func assertEqual(t *testing.T, actual string, expected string) {
|
|||||||
t.Errorf("values differ: Actual: '%v', Expected: '%v'", actual, expected)
|
t.Errorf("values differ: Actual: '%v', Expected: '%v'", actual, expected)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func assertNotEqual(t *testing.T, actual string, expected string) {
|
||||||
|
if actual == expected {
|
||||||
|
t.Errorf("values do not differ: Actual: '%v', Expected: '%v'", actual, expected)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Reference in New Issue
Block a user