diff --git a/src/components/connect.js b/src/components/connect.js
index 48a8b4ac9..475c2e714 100644
--- a/src/components/connect.js
+++ b/src/components/connect.js
@@ -1,4 +1,4 @@
-import React, { Component } from 'react'
+import React, { Component, createElement } from 'react'
import storeShape from '../utils/storeShape'
import shallowEqual from '../utils/shallowEqual'
import isPlainObject from '../utils/isPlainObject'
@@ -28,8 +28,8 @@ export default function connect(mapStateToProps, mapDispatchToProps, mergeProps,
wrapActionCreators(mapDispatchToProps) :
mapDispatchToProps || defaultMapDispatchToProps
const finalMergeProps = mergeProps || defaultMergeProps
- const shouldUpdateStateProps = finalMapStateToProps.length !== 1
- const shouldUpdateDispatchProps = finalMapDispatchToProps.length !== 1
+ const doStatePropsDependOnOwnProps = finalMapStateToProps.length !== 1
+ const doDispatchPropsDependOnOwnProps = finalMapDispatchToProps.length !== 1
const { pure = true, withRef = false } = options
// Helps track hot reloading.
@@ -37,7 +37,7 @@ export default function connect(mapStateToProps, mapDispatchToProps, mergeProps,
function computeStateProps(store, props) {
const state = store.getState()
- const stateProps = shouldUpdateStateProps ?
+ const stateProps = doStatePropsDependOnOwnProps ?
finalMapStateToProps(state, props) :
finalMapStateToProps(state)
@@ -51,7 +51,7 @@ export default function connect(mapStateToProps, mapDispatchToProps, mergeProps,
function computeDispatchProps(store, props) {
const { dispatch } = store
- const dispatchProps = shouldUpdateDispatchProps ?
+ const dispatchProps = doDispatchPropsDependOnOwnProps ?
finalMapDispatchToProps(dispatch, props) :
finalMapDispatchToProps(dispatch)
@@ -63,7 +63,7 @@ export default function connect(mapStateToProps, mapDispatchToProps, mergeProps,
return dispatchProps
}
- function computeNextState(stateProps, dispatchProps, parentProps) {
+ function computeMergedProps(stateProps, dispatchProps, parentProps) {
const mergedProps = finalMergeProps(stateProps, dispatchProps, parentProps)
invariant(
isPlainObject(mergedProps),
@@ -75,33 +75,8 @@ export default function connect(mapStateToProps, mapDispatchToProps, mergeProps,
return function wrapWithConnect(WrappedComponent) {
class Connect extends Component {
- shouldComponentUpdate(nextProps, nextState) {
- if (!pure) {
- this.updateStateProps(nextProps)
- this.updateDispatchProps(nextProps)
- this.updateState(nextProps)
- return true
- }
-
- const storeChanged = nextState.storeState !== this.state.storeState
- const propsChanged = !shallowEqual(nextProps, this.props)
- let mapStateProducedChange = false
- let dispatchPropsChanged = false
-
- if (storeChanged || (propsChanged && shouldUpdateStateProps)) {
- mapStateProducedChange = this.updateStateProps(nextProps)
- }
-
- if (propsChanged && shouldUpdateDispatchProps) {
- dispatchPropsChanged = this.updateDispatchProps(nextProps)
- }
-
- if (propsChanged || mapStateProducedChange || dispatchPropsChanged) {
- this.updateState(nextProps)
- return true
- }
-
- return false
+ shouldComponentUpdate() {
+ return !pure || this.haveOwnPropsChanged || this.hasStoreStateChanged
}
constructor(props, context) {
@@ -116,23 +91,14 @@ export default function connect(mapStateToProps, mapDispatchToProps, mergeProps,
`or explicitly pass "store" as a prop to "${this.constructor.displayName}".`
)
- this.stateProps = computeStateProps(this.store, props)
- this.dispatchProps = computeDispatchProps(this.store, props)
- this.state = { storeState: this.store.getState() }
- this.updateState()
+ const storeState = this.store.getState()
+ this.state = { storeState }
+ this.clearCache()
}
- computeNextState(props = this.props) {
- return computeNextState(
- this.stateProps,
- this.dispatchProps,
- props
- )
- }
-
- updateStateProps(props = this.props) {
- const nextStateProps = computeStateProps(this.store, props)
- if (shallowEqual(nextStateProps, this.stateProps)) {
+ updateStatePropsIfNeeded() {
+ const nextStateProps = computeStateProps(this.store, this.props)
+ if (this.stateProps && shallowEqual(nextStateProps, this.stateProps)) {
return false
}
@@ -140,9 +106,9 @@ export default function connect(mapStateToProps, mapDispatchToProps, mergeProps,
return true
}
- updateDispatchProps(props = this.props) {
- const nextDispatchProps = computeDispatchProps(this.store, props)
- if (shallowEqual(nextDispatchProps, this.dispatchProps)) {
+ updateDispatchPropsIfNeeded() {
+ const nextDispatchProps = computeDispatchProps(this.store, this.props)
+ if (this.dispatchProps && shallowEqual(nextDispatchProps, this.dispatchProps)) {
return false
}
@@ -150,8 +116,12 @@ export default function connect(mapStateToProps, mapDispatchToProps, mergeProps,
return true
}
- updateState(props = this.props) {
- this.nextState = this.computeNextState(props)
+ updateMergedProps() {
+ this.mergedProps = computeMergedProps(
+ this.stateProps,
+ this.dispatchProps,
+ this.props
+ )
}
isSubscribed() {
@@ -176,8 +146,24 @@ export default function connect(mapStateToProps, mapDispatchToProps, mergeProps,
this.trySubscribe()
}
+ componentWillReceiveProps(nextProps) {
+ if (!pure || !shallowEqual(nextProps, this.props)) {
+ this.haveOwnPropsChanged = true
+ }
+ }
+
componentWillUnmount() {
this.tryUnsubscribe()
+ this.clearCache()
+ }
+
+ clearCache() {
+ this.dispatchProps = null
+ this.stateProps = null
+ this.mergedProps = null
+ this.haveOwnPropsChanged = true
+ this.hasStoreStateChanged = true
+ this.renderedElement = null
}
handleChange() {
@@ -185,9 +171,13 @@ export default function connect(mapStateToProps, mapDispatchToProps, mergeProps,
return
}
- this.setState({
- storeState: this.store.getState()
- })
+ const prevStoreState = this.state.storeState
+ const storeState = this.store.getState()
+
+ if (!pure || prevStoreState !== storeState) {
+ this.hasStoreStateChanged = true
+ this.setState({ storeState })
+ }
}
getWrappedInstance() {
@@ -200,10 +190,61 @@ export default function connect(mapStateToProps, mapDispatchToProps, mergeProps,
}
render() {
- const ref = withRef ? 'wrappedInstance' : null
- return (
-
- )
+ const {
+ haveOwnPropsChanged,
+ hasStoreStateChanged,
+ renderedElement
+ } = this
+
+ this.haveOwnPropsChanged = false
+ this.hasStoreStateChanged = false
+
+ let shouldUpdateStateProps = true
+ let shouldUpdateDispatchProps = true
+ if (pure && renderedElement) {
+ shouldUpdateStateProps = hasStoreStateChanged || (
+ haveOwnPropsChanged && doStatePropsDependOnOwnProps
+ )
+ shouldUpdateDispatchProps =
+ haveOwnPropsChanged && doDispatchPropsDependOnOwnProps
+ }
+
+ let haveStatePropsChanged = false
+ let haveDispatchPropsChanged = false
+ if (shouldUpdateStateProps) {
+ haveStatePropsChanged = this.updateStatePropsIfNeeded()
+ }
+ if (shouldUpdateDispatchProps) {
+ haveDispatchPropsChanged = this.updateDispatchPropsIfNeeded()
+ }
+
+ let haveMergedPropsChanged = true
+ if (
+ haveStatePropsChanged ||
+ haveDispatchPropsChanged ||
+ haveOwnPropsChanged
+ ) {
+ this.updateMergedProps()
+ } else {
+ haveMergedPropsChanged = false
+ }
+
+ if (!haveMergedPropsChanged && renderedElement) {
+ return renderedElement
+ }
+
+ if (withRef) {
+ this.renderedElement = createElement(WrappedComponent, {
+ ...this.mergedProps,
+ ref: 'wrappedInstance'
+ })
+ } else {
+ this.renderedElement = createElement(WrappedComponent,
+ this.mergedProps
+ )
+ }
+
+ return this.renderedElement
}
}
@@ -224,12 +265,8 @@ export default function connect(mapStateToProps, mapDispatchToProps, mergeProps,
// We are hot reloading!
this.version = version
-
- // Update the state and bindings.
this.trySubscribe()
- this.updateStateProps()
- this.updateDispatchProps()
- this.updateState()
+ this.clearCache()
}
}
diff --git a/test/components/connect.spec.js b/test/components/connect.spec.js
index 5b3444211..b9ac9b809 100644
--- a/test/components/connect.spec.js
+++ b/test/components/connect.spec.js
@@ -201,6 +201,44 @@ describe('React', () => {
expect(stub.props.pass).toEqual('through')
})
+ it('should handle unexpected prop changes with forceUpdate()', () => {
+ const store = createStore(() => ({}))
+
+ @connect(state => state)
+ class ConnectContainer extends Component {
+ render() {
+ return (
+
+ )
+ }
+ }
+
+ class Container extends Component {
+ constructor() {
+ super()
+ this.bar = 'baz'
+ }
+
+ componentDidMount() {
+ this.bar = 'foo'
+ this.forceUpdate()
+ this.c.forceUpdate()
+ }
+
+ render() {
+ return (
+
+ this.c = c} />
+
+ )
+ }
+ }
+
+ const container = TestUtils.renderIntoDocument()
+ const stub = TestUtils.findRenderedComponentWithType(container, Passthrough)
+ expect(stub.props.bar).toEqual('foo')
+ })
+
it('should remove undefined props', () => {
const store = createStore(() => ({}))
let props = { x: true }
@@ -323,7 +361,7 @@ describe('React', () => {
return (
-
+
)
}
}