initial commit
[govpp.git] / vendor / github.com / onsi / gomega / internal / asyncassertion / async_assertion.go
1 package asyncassertion
2
3 import (
4         "errors"
5         "fmt"
6         "reflect"
7         "time"
8
9         "github.com/onsi/gomega/internal/oraclematcher"
10         "github.com/onsi/gomega/types"
11 )
12
// AsyncAssertionType selects how an AsyncAssertion interprets its polls.
type AsyncAssertionType uint

// The supported AsyncAssertionType values.
const (
	// AsyncAssertionTypeEventually passes as soon as a poll satisfies the
	// assertion and fails if the timeout elapses first.
	AsyncAssertionTypeEventually AsyncAssertionType = 0
	// AsyncAssertionTypeConsistently fails as soon as a poll does not satisfy
	// the assertion and passes if the timeout elapses first.
	AsyncAssertionTypeConsistently AsyncAssertionType = 1
)
19
20 type AsyncAssertion struct {
21         asyncType       AsyncAssertionType
22         actualInput     interface{}
23         timeoutInterval time.Duration
24         pollingInterval time.Duration
25         fail            types.GomegaFailHandler
26         offset          int
27 }
28
29 func New(asyncType AsyncAssertionType, actualInput interface{}, fail types.GomegaFailHandler, timeoutInterval time.Duration, pollingInterval time.Duration, offset int) *AsyncAssertion {
30         actualType := reflect.TypeOf(actualInput)
31         if actualType.Kind() == reflect.Func {
32                 if actualType.NumIn() != 0 || actualType.NumOut() == 0 {
33                         panic("Expected a function with no arguments and one or more return values.")
34                 }
35         }
36
37         return &AsyncAssertion{
38                 asyncType:       asyncType,
39                 actualInput:     actualInput,
40                 fail:            fail,
41                 timeoutInterval: timeoutInterval,
42                 pollingInterval: pollingInterval,
43                 offset:          offset,
44         }
45 }
46
47 func (assertion *AsyncAssertion) Should(matcher types.GomegaMatcher, optionalDescription ...interface{}) bool {
48         return assertion.match(matcher, true, optionalDescription...)
49 }
50
51 func (assertion *AsyncAssertion) ShouldNot(matcher types.GomegaMatcher, optionalDescription ...interface{}) bool {
52         return assertion.match(matcher, false, optionalDescription...)
53 }
54
55 func (assertion *AsyncAssertion) buildDescription(optionalDescription ...interface{}) string {
56         switch len(optionalDescription) {
57         case 0:
58                 return ""
59         default:
60                 return fmt.Sprintf(optionalDescription[0].(string), optionalDescription[1:]...) + "\n"
61         }
62 }
63
64 func (assertion *AsyncAssertion) actualInputIsAFunction() bool {
65         actualType := reflect.TypeOf(assertion.actualInput)
66         return actualType.Kind() == reflect.Func && actualType.NumIn() == 0 && actualType.NumOut() > 0
67 }
68
69 func (assertion *AsyncAssertion) pollActual() (interface{}, error) {
70         if assertion.actualInputIsAFunction() {
71                 values := reflect.ValueOf(assertion.actualInput).Call([]reflect.Value{})
72
73                 extras := []interface{}{}
74                 for _, value := range values[1:] {
75                         extras = append(extras, value.Interface())
76                 }
77
78                 success, message := vetExtras(extras)
79
80                 if !success {
81                         return nil, errors.New(message)
82                 }
83
84                 return values[0].Interface(), nil
85         }
86
87         return assertion.actualInput, nil
88 }
89
90 func (assertion *AsyncAssertion) matcherMayChange(matcher types.GomegaMatcher, value interface{}) bool {
91         if assertion.actualInputIsAFunction() {
92                 return true
93         }
94
95         return oraclematcher.MatchMayChangeInTheFuture(matcher, value)
96 }
97
98 func (assertion *AsyncAssertion) match(matcher types.GomegaMatcher, desiredMatch bool, optionalDescription ...interface{}) bool {
99         timer := time.Now()
100         timeout := time.After(assertion.timeoutInterval)
101
102         description := assertion.buildDescription(optionalDescription...)
103
104         var matches bool
105         var err error
106         mayChange := true
107         value, err := assertion.pollActual()
108         if err == nil {
109                 mayChange = assertion.matcherMayChange(matcher, value)
110                 matches, err = matcher.Match(value)
111         }
112
113         fail := func(preamble string) {
114                 errMsg := ""
115                 message := ""
116                 if err != nil {
117                         errMsg = "Error: " + err.Error()
118                 } else {
119                         if desiredMatch {
120                                 message = matcher.FailureMessage(value)
121                         } else {
122                                 message = matcher.NegatedFailureMessage(value)
123                         }
124                 }
125                 assertion.fail(fmt.Sprintf("%s after %.3fs.\n%s%s%s", preamble, time.Since(timer).Seconds(), description, message, errMsg), 3+assertion.offset)
126         }
127
128         if assertion.asyncType == AsyncAssertionTypeEventually {
129                 for {
130                         if err == nil && matches == desiredMatch {
131                                 return true
132                         }
133
134                         if !mayChange {
135                                 fail("No future change is possible.  Bailing out early")
136                                 return false
137                         }
138
139                         select {
140                         case <-time.After(assertion.pollingInterval):
141                                 value, err = assertion.pollActual()
142                                 if err == nil {
143                                         mayChange = assertion.matcherMayChange(matcher, value)
144                                         matches, err = matcher.Match(value)
145                                 }
146                         case <-timeout:
147                                 fail("Timed out")
148                                 return false
149                         }
150                 }
151         } else if assertion.asyncType == AsyncAssertionTypeConsistently {
152                 for {
153                         if !(err == nil && matches == desiredMatch) {
154                                 fail("Failed")
155                                 return false
156                         }
157
158                         if !mayChange {
159                                 return true
160                         }
161
162                         select {
163                         case <-time.After(assertion.pollingInterval):
164                                 value, err = assertion.pollActual()
165                                 if err == nil {
166                                         mayChange = assertion.matcherMayChange(matcher, value)
167                                         matches, err = matcher.Match(value)
168                                 }
169                         case <-timeout:
170                                 return true
171                         }
172                 }
173         }
174
175         return false
176 }
177
// vetExtras checks the extra values that accompany a polled function's
// primary result.
//
// It returns (true, "") when every element is nil or the zero value of its
// dynamic type. Otherwise it returns false and a message describing the first
// offending element. The reported index is i+1, which matches the element's
// position among return values when extras are a function's trailing results.
func vetExtras(extras []interface{}) (bool, string) {
	for idx := range extras {
		candidate := extras[idx]
		if candidate == nil {
			continue
		}

		zero := reflect.Zero(reflect.TypeOf(candidate)).Interface()
		if reflect.DeepEqual(zero, candidate) {
			continue
		}

		return false, fmt.Sprintf("Unexpected non-nil/non-zero extra argument at index %d:\n\t<%T>: %#v", idx+1, candidate, candidate)
	}
	return true, ""
}